In [None]:
from PIL import Image
from PIL import ImageOps
import numpy as np
from pathlib import Path
import os
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torch.optim as optim
import random

root = Path(".")

# Seed every RNG the notebook relies on (negative-pixel sampling and the shuffle use
# `random`; weight initialisation uses torch) so Restart & Run All is reproducible.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
class ConvNet(nn.Module):
    """Small 1-D CNN that maps a single-channel sequence to 2 class logits.

    Stages (registered in this order, so parameter initialisation and
    ``state_dict`` keys are fixed):

    * ``layer1``: Conv1d(1 -> 10, k=3, pad=1) + ReLU + MaxPool1d(3)
    * ``layer2``: Conv1d(10 -> 10, k=3, pad=1) + ReLU + MaxPool1d(3),
      applied three times in ``forward`` with the SAME weights
    * ``layer4``: Linear(70 -> 10) + ReLU
    * ``layer5``: Linear(10 -> 10) + ReLU, applied twice with the SAME weights
    * ``fc``: Linear(10 -> 2), raw logits (no softmax)

    Input is expected as ``(batch, 1, L)``. The convolutions keep the length, and
    each of the four poolings floors it by 3. ``layer4`` needs 10 channels x 7
    samples after flattening, so ``L // 81 == 7``, i.e. ``567 <= L <= 647``.

    NOTE(review): there is no ``layer3``, and ``layer2``/``layer5`` are reused.
    Confirm whether the weight sharing is intentional or separate layers were meant.
    """

    def __init__(self):
        super().__init__()

        def conv_stage(in_channels):
            # conv (length-preserving) -> ReLU -> pool (length // 3)
            return nn.Sequential(
                nn.Conv1d(in_channels, 10, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.MaxPool1d(kernel_size=3, stride=3),
            )

        def dense_stage(in_features):
            return nn.Sequential(nn.Linear(in_features, 10), nn.ReLU())

        self.layer1 = conv_stage(1)
        self.layer2 = conv_stage(10)
        self.layer4 = dense_stage(10 * 7)
        self.layer5 = dense_stage(10)
        self.fc = nn.Linear(10, 2)

    def forward(self, x):
        features = x
        for stage in (self.layer1, self.layer2, self.layer2, self.layer2):
            features = stage(features)
        # (batch, 10, 7) -> (batch, 70)
        features = features.reshape(features.size(0), -1)
        for stage in (self.layer4, self.layer5, self.layer5):
            features = stage(features)
        return self.fc(features)

In [None]:
class dataset:
    """Per-pixel intensity time series extracted from a directory of ``.tif`` frames.

    Every frame is converted to grayscale, resized to 256x256 and autocontrasted
    (see :meth:`sample`). For a pixel coordinate, :meth:`tensor` returns that
    pixel's value across all frames, smoothed by a moving average and scaled from
    0-255 to [0, 1]. Pixel classes come from ``indices.txt``: non-zero entries are
    positives; zero entries are negatives ("controls").

    Attributes:
        directory: ``Path`` of the frame directory.
        name: last three characters of the directory path (e.g. ``"002"``).
        filepaths: sorted paths of every file whose name contains ``".tif"``.
        num_images: ``len(filepaths)``.
        resolution: ``(width, height)`` of the first frame as stored on disk.
        ground_truth: 2-D float array loaded from ``indices.txt``.
        device: ``cuda:0`` when CUDA is available, else CPU (not used by this class).
    """

    def __init__(self, directory, ground_truth_id="002"):
        """
        Args:
            directory: frame directory (``Path`` or ``str``).
            ground_truth_id: sub-folder of ``<directory>/../../output`` holding
                ``indices.txt``. Defaults to ``"002"`` as before.
                NOTE(review): probably meant to equal ``self.name``. Confirm.

        Raises:
            FileNotFoundError: if the directory contains no ``.tif`` frames.
        """
        directory = Path(directory)
        self.directory = directory
        self.name = str(directory)[-3:]
        # os.listdir() order is arbitrary, but tensor() uses list position as the time
        # index, so sort. NOTE(review): lexicographic order is chronological only if
        # frame numbers in the filenames are zero-padded. Confirm the naming scheme.
        self.filepaths = sorted(directory / file for file in os.listdir(directory) if ".tif" in file)
        if not self.filepaths:
            raise FileNotFoundError(f"no .tif frames found in {directory}")
        self.num_images = len(self.filepaths)
        # Read the size from an actual frame (not whatever listdir lists first) and close it.
        with Image.open(self.filepaths[0]) as first_frame:
            self.resolution = first_frame.size
        self.ground_truth = np.loadtxt(
            directory / ".." / ".." / "output" / ground_truth_id / "indices.txt", dtype=float)
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        # Lazily filled cache: every processed frame stacked along axis 0.
        self._frames = None

    def sample(self, file):
        """Load one frame as a 256x256 grayscale ``uint8`` array with autocontrast applied."""
        with Image.open(file) as img:  # close the file handle once decoded
            gray = img.convert('L').resize((256, 256))
        return np.array(ImageOps.autocontrast(gray))

    def _frame_stack(self):
        """Return all processed frames stacked along axis 0, decoding them only once."""
        if getattr(self, "_frames", None) is None:
            self._frames = np.stack([self.sample(file) for file in self.filepaths])
        return self._frames

    def controls(self):
        """Randomly pick as many negative (zero) pixels as there are positive ones.

        Returns a list of ``(row, col)`` tuples drawn with :func:`random.sample`, so the
        result depends on the global ``random`` state. ``random.sample`` raises
        ``ValueError`` if there are more positive than negative pixels.
        """
        neg_rows, neg_cols = np.where(self.ground_truth == 0)
        negatives = list(zip(neg_rows, neg_cols))
        positive_count = self.ground_truth.size - len(negatives)
        return random.sample(negatives, positive_count)

    def positives(self):
        """Return ``(row, col)`` tuples of every non-zero ground-truth pixel, in row-major order."""
        rows, cols = np.where(self.ground_truth != 0)
        return list(zip(rows, cols))

    def tensor(self, coordinates, frames_averaged=12):
        """Return one pixel's smoothed, normalised time series as a ``(1, T)`` tensor.

        Args:
            coordinates: ``(row, col)`` index into a processed 256x256 frame.
            frames_averaged: moving-average window length (default 12).

        Returns:
            float32 tensor of shape ``(1, num_images - frames_averaged + 1)`` with
            values in [0, 1].

        Raises:
            ValueError: if there are fewer frames than ``frames_averaged``.
        """
        series = self._frame_stack()[(slice(None),) + tuple(coordinates)]
        if len(series) < frames_averaged:
            raise ValueError(
                f"need at least {frames_averaged} frames for the moving average, got {len(series)}")
        smoothed = self.moving_average(series, frames_averaged)
        # Cast to float32 first, then scale, matching the original arithmetic order.
        return torch.from_numpy(smoothed).float()[None, :] / 255.0

    @staticmethod
    def moving_average(a, n=3):
        """Valid-mode moving average of window ``n`` via cumulative sums (length ``len(a) - n + 1``)."""
        ret = np.cumsum(a, dtype=float)
        ret[n:] = ret[n:] - ret[:-n]
        return ret[n - 1:] / n

    def create_batch(self, p, coordinate):
        """Build a single-sample ``(labels, data)`` batch for one pixel.

        Args:
            p: pixel class; ``1`` maps to label 1, anything else to label 0.
            coordinate: ``(row, col)`` pixel passed to :meth:`tensor`.

        Returns:
            ``labels``: int64 tensor of shape ``(1,)``.
            ``data``: float32 tensor of shape ``(1, 1, T)`` with ``requires_grad=True``.
        """
        labels = torch.tensor([1 if p == 1 else 0], dtype=torch.int64)
        data = self.tensor(coordinate)[None, :, :]
        # NOTE(review): the training loop never uses input gradients. This is kept
        # only so the returned tensor is unchanged; consider dropping it.
        data.requires_grad = True
        return labels, data

In [None]:
data = dataset(root / "data" / "002")

# Positive pixels (non-zero ground truth) plus an equally sized random draw of
# negative "control" pixels, so both classes are balanced.
positives = data.positives()
controls = data.controls()
all_data = [*positives, *controls]
labels = [1] * len(positives) + [0] * len(controls)

# Pair every label with its (row, col) coordinate, then shuffle into a new list.
package = random.sample(list(zip(labels, all_data)), len(all_data))

In [None]:
# Place the model on the device the dataset chose (CUDA when available, otherwise CPU)
# instead of calling .cuda() unconditionally, which fails on CPU-only machines.
device = data.device
net = ConvNet().to(device)
criterion = nn.CrossEntropyLoss(reduction='mean')
# NOTE(review): 0.1 is far above Adam's default learning rate (1e-3) and may make
# training unstable. It is kept unchanged here; tune it before trusting results.
LEARNING_RATE = 0.1
optimizer = optim.Adam(net.parameters(), lr=LEARNING_RATE)

In [None]:
from IPython import display

plt.style.use('dark_background')

N_EPOCHS = 200
# Rendering the figure is far slower than one training step, so redraw only every
# PLOT_EVERY iterations (and always on the last iteration of each epoch).
PLOT_EVERY = 25

fig, ax = plt.subplots(1, 1)
loss_ax = []  # training loss of every iteration, in order
for epoch in range(N_EPOCHS):
    for i, (label, coordinate) in enumerate(package):
        optimizer.zero_grad()

        # Single-sample batch for this pixel, moved to the dataset's device.
        labels, inputs = data.create_batch(label, coordinate)
        labels, inputs = labels.to(data.device), inputs.to(data.device)

        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        loss_ax.append(loss.item())
        if i % PLOT_EVERY == 0 or i == len(package) - 1:
            recent = loss_ax[-100:]
            ax.cla()
            # Plot the last <=100 losses in decibels (10 * log10).
            ax.scatter(range(len(recent)), 10 * np.log10(recent), s=9, marker='s')
            ax.set_title("epoch %d, iter %d" % (epoch, i))
            ax.set_xlabel("iteration (last 100)")
            ax.set_ylabel("loss [dB]")
            display.clear_output(wait=True)
            display.display(fig)

plt.close(fig)  # avoid an extra static copy of the figure when the cell finishes
print('Finished Training')

In [None]:
# Print the trained model's logits next to the true label for every training pixel.
# NOTE(review): these are the same pixels the model was trained on, so this is not a
# held-out evaluation.
with torch.no_grad():  # inference only: don't build an autograd graph
    for label, coordinate in package:
        _, inputs = data.create_batch(label, coordinate)
        inputs = inputs.to(data.device)
        print(label, net(inputs))
