In [2]:
from Dataloader import *
import torchvision
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import time
import matplotlib.pyplot as plt
from PerUnet import *
from MultiCNN import *
from MinMaxNet import *
import torch
import sys
import os
import scgen



In [2]:
epochs = 60
lr = 0.1
batch_size = 32
model_name = 2
weights_number = 1

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=None)
new_trainset = CIFAR10dataset(trainset)

trainloader = torch.utils.data.DataLoader(new_trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=None)
new_testset = CIFAR10dataset(testset)

testloader = torch.utils.data.DataLoader(new_testset, batch_size=1,
                                         shuffle=False, num_workers=2)

Files already downloaded and verified
Files already downloaded and verified


In [3]:
class WMSELoss(nn.Module):

    def __init__(self, weights):
        super().__init__()
        self.weights = weights

    def forward(self, inputs, output):
        return torch.sum(self.weights * (inputs - output) ** 2)

In [4]:
def init_weights(m):
    if (type(m) == nn.Conv2d) or (type(m) == nn.ConvTranspose2d) or (type(m) == nn.Linear) or (type(m) == nn.Conv3d):
        nn.init.xavier_uniform_(m.weight.data)
        nn.init.zeros_(m.bias.data)

if model_name == 0:
    model = PerUnet()
    MODEL_TYPE = str('PerUnet')
elif model_name == 1:
    model = MultiCNN()
    MODEL_TYPE = str('MultiCNN')
elif model_name == 2:
    model = MinMaxNet()
    MODEL_TYPE = str('MinMaxNet')

model.apply(init_weights)

weights1 = (torch.arange(20, 0, -1)).float() /210

if weights_number == 0:
    criterion = nn.MSELoss()
elif weights_number == 1:
    criterion = WMSELoss(weights = weights1)

#print(weights)

optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.1)

use_gpu = torch.cuda.is_available()
if use_gpu:
    model = model.cuda()

In [5]:
dataiter = iter(trainloader)
images, lifetime, labels = dataiter.next()

# show images
#plt.imshow(torchvision.utils.make_grid(images))
print(images.shape)
print(lifetime.shape)
#print(pdgm)

inputs_image = torch.as_tensor(images, device=torch.device('cpu'))
GT_image = torch.as_tensor(lifetime, dtype=torch.float32, device=torch.device('cpu'))
outputs = model(inputs_image)

print(outputs.type())
print(GT_image.type())
loss = criterion(outputs, GT_image)
print(loss)

torch.Size([32, 1, 32, 32])
torch.Size([32, 20])
torch.FloatTensor
torch.FloatTensor
tensor(5.6148, grad_fn=<SumBackward0>)


In [6]:
def progressBar(i, max, text):
    bar_size = 30
    j = i / max
    sys.stdout.write('\r')
    sys.stdout.write(
        f"[{'=' * int(bar_size * j):{bar_size}s}] {int(100 * j)}%  {text}")
    sys.stdout.flush()

def train():

    train_loss_list = list()
    val_loss_list = list()
    test_loss_list = list()

    for epoch in range(epochs):
        print("epoch:", epoch)
        ts = time.time()

        train_loss = 0
        val_loss = 0
        test_loss = 0

        for iter, (X, Y, labels) in enumerate(trainloader):

            # training on the first 10000 images

            optimizer.zero_grad()

            inputs_image = torch.as_tensor(X, device=torch.device('cpu'))
            GT_image = torch.as_tensor(Y, dtype=torch.float32, device=torch.device('cpu'))

            #start = torch.cuda.Event(enable_timing=True)
            #end = torch.cuda.Event(enable_timing=True)


            outputs = model(inputs_image)


            loss = criterion(outputs, GT_image)


            loss.backward()

            #start.record()
            optimizer.step()
            #end.record()

            # Waits for everything to finish running
            #torch.cuda.synchronize()

            #print(start.elapsed_time(end))

            train_loss += loss.item() / len(trainloader) / batch_size

            progressBar(iter + 1, len(trainloader),
                            "Train Progress")

            '''

            # validating on the rest 10000 images
            while iter > 0.2 * len(trainloader) and iter <= 0.4 * len(trainloader):
                inputs_image = torch.as_tensor(X, device=torch.device('cuda'))
                GT_image = torch.as_tensor(Y, dtype=torch.float32, device=torch.device('cuda'))

                outputs = model(inputs_image)
                loss = criterion(outputs, GT_image)

                val_loss += loss.item()

                progressBar(iter + 1, 0.2 * len(trainloader) / batch_size,
                        "Validation Progress")

            if iter > 0.4 * len(new_trainset):
                break
            '''
        print("\n train loss {}".format(train_loss))
        train_loss_list.append(train_loss)
        #val_loss_list.append(val_loss)

        # testing on 10000 images
        for iter, (X, Y, labels) in enumerate(testloader):

            inputs_image = torch.as_tensor(X, device=torch.device('cpu'))
            GT_image = torch.as_tensor(Y, dtype=torch.float32, device=torch.device('cpu'))

            outputs = model(inputs_image)
            loss = criterion(outputs, GT_image)

            test_loss += loss.item() / len(testloader)

            progressBar(iter + 1, len(testloader),
                        "Test Progress")

        print("\n test loss {}".format(test_loss))

        test_loss_list.append(test_loss)

        plot_performance(train_loss_list, test_loss_list)

        model.train()

        print("Finish epoch {}, time elapsed {}".format(epoch, time.time() - ts))

        if epoch % 10 == 0:
            torch.save(model, 'Model'+ MODEL_TYPE + "_lr=" +str(lr) + 'weight=' + str(weights_number))
    return

def plot_performance(train_loss_list, test_loss_list):
    Title = "Model=" + MODEL_TYPE + \
            ";  batch_size=" + str(batch_size) + \
            ";  lr=" + str(lr) + \
            ";  epochs=" + str(epochs) + \
            ";  Weighted Loss = weights" + str(weights_number)
    plt.plot(train_loss_list, label="Train Loss")
    #plt.plot(val_loss_list, label="Validation Loss")
    plt.plot(test_loss_list, label="Test Loss")
    plt.legend(bbox_to_anchor=(0.8, 1), loc='upper left')
    #plt.set(xlabel='epoch', ylabel='loss')
    plt.title(Title)
    plt.show()

In [1]:
import numpy as np

x = np.random.choice(range(0,2437), 1949)
y = np.random.choice(range(0,2437), 1222)

print(x)
print(y)

[1470 1517 1947 ...  802 2108 1810]
[ 456  846 1694 ...  679  382 1445]


In [7]:
import torch

a = torch.tensor([1,2,2])

print(torch.max(a, dim=0))

torch.return_types.max(
values=tensor(2),
indices=tensor(1))
