In [None]:
from pathlib import Path
import torch
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
import numpy as np
import torchvision.transforms.functional as TF
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from torchsummary import summary

In [None]:
import torchvision.models as models
import torchvision
import torch.nn as nn
from torch import Tensor
from torchvision.models.resnet import BasicBlock,Bottleneck
import torch.optim as optim

In [None]:
# Base directory (the current working directory) under which the CIFAR-10 data
# (path/'data') and the model checkpoints are stored.
path = Path('.')

In [None]:
class Unet_ResNet(nn.Module):
    """Residual "lens" network: computes a 3-channel correction and adds it to its input.

    Layout: a 1x1 conv stem (3 -> 64 channels) with BatchNorm and ReLU, five
    torchvision ``Bottleneck`` blocks (the first one carries a 1x1 projection
    shortcut from 64 to 256 channels), and a 1x1 conv back to 3 channels followed
    by ReLU. Because the result is added to the input, the correction must have
    the same shape as the input.
    """

    def __init__(self):
        super().__init__()
        # Stem: lift the 3 input channels to 64 feature channels.
        self.conv_input = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
            nn.ReLU(inplace=True),
        )
        # Projection shortcut so the first block's skip path has 256 channels.
        shortcut = nn.Sequential(
            nn.Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False),
            nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
        )
        stages = [Bottleneck(64, 64, downsample=shortcut)]
        stages += [Bottleneck(256, 64) for _ in range(4)]
        self.blocks = nn.Sequential(*stages)
        # Head: project back to 3 channels; the ReLU keeps the correction non-negative.
        self.conv_end = nn.Sequential(
            nn.Conv2d(in_channels=256, out_channels=3, kernel_size=1, stride=1, padding=0),
            nn.ReLU(inplace=True),
        )
        # Weight initialisation follows the torchvision ResNet source:
        # https://pytorch.org/vision/0.8/_modules/torchvision/models/resnet.html
        for module in self.modules():
            if isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")

    def forward(self, x: Tensor) -> Tensor:
        """Return ``x`` plus the correction predicted from ``x``."""
        correction = self.conv_end(self.blocks(self.conv_input(x)))
        return x + correction

In [None]:
def imshow(img):
    """Plot a tensor image after mapping it through ``img / 2 + 0.5``.

    The first axis is moved to the end before plotting (axes (0, 1, 2) become
    (1, 2, 0)). ``img`` must support ``.numpy()``, i.e. be a CPU tensor that
    does not require grad.
    """
    rescaled = img / 2 + 0.5
    pixels = np.transpose(rescaled.numpy(), (1, 2, 0))
    plt.imshow(pixels)
    plt.show()

In [None]:
# Ten object-class names (presumably the CIFAR-10 label order — TODO confirm).
# Not referenced by any other cell visible in this notebook.
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

In [None]:
def add_arrow(img):
    """Draw a small arrow into ``img`` in place and return the same object.

    Every position along the first axis (the channel axis for CHW images) is
    set to -1 at these (row, column) locations:
      * rows 2..8 of column 8       - the shaft
      * row 3, columns 7..9         - upper part of the head
      * row 4, columns 6..10        - lower, wider part of the head
    """
    start = 2
    # Shaft: a vertical run of 7 pixels.
    img[:, start:start + 7, start + 6] = -1
    # Head: two horizontal runs centred on the shaft, 3 and 5 pixels wide.
    img[:, start + 1, start + 5:start + 8] = -1
    img[:, start + 2, start + 4:start + 9] = -1
    return img

In [None]:
class arrowedCIFAR(Dataset):
    """Rotation-prediction dataset built from a random subset of CIFAR-10.

    ``num_samples`` indices are drawn (with replacement) from the chosen CIFAR-10
    split. Each drawn image contributes four items: the image itself (label 0)
    and the image rotated by 90, 180 and 270 degrees (labels 1, 2, 3). The
    original CIFAR class label is discarded.

    Args:
        train: use the CIFAR-10 training split when truthy, else the test split.
        clean_data: when False, ``add_arrow`` is applied to each image before it
            is stored and rotated; when True the images are left untouched.
        num_samples: number of CIFAR images to draw; the dataset holds
            ``4 * num_samples`` items. Defaults to 256*20.
    """

    def __init__(self, train=True, clean_data = False, num_samples = 256*20):
        transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        self.train = train
        self.cifar = torchvision.datasets.CIFAR10(root = path/'data', download = True,
                                                  transform=transform, train = bool(self.train))
        self.data = []
        self.labels = []
        # np.random.randint samples with replacement, so an image may appear more than once.
        indices = np.random.randint(low = 0, high = len(self.cifar), size=num_samples)
        for i in tqdm(indices):
            # Load and transform each image once; TF.rotate returns a new tensor,
            # so the stored unrotated image is not affected by the rotations below.
            img, _ = self.cifar[i]  # Only care about the rotation, not the class label
            if not clean_data:
                img = add_arrow(img)
            self.data.append(img)
            self.labels.append(0)
            for label, angle in enumerate((90, 180, 270), start=1):
                self.data.append(TF.rotate(img, angle))
                self.labels.append(label)

    def __len__(self):
        """Number of stored (image, rotation label) pairs."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return the ``(image, rotation_label)`` pair stored at ``idx``."""
        return self.data[idx], self.labels[idx]

In [None]:
# Build the train and test rotation datasets. clean_data=True means add_arrow is
# NOT applied, so these images carry no arrow marker.
trainset = arrowedCIFAR(train=True, clean_data = True)
testset = arrowedCIFAR(train=False, clean_data = True)

In [None]:
# Shuffled mini-batches of 512 items drawn from the training set.
batch_size = 512
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True)

In [None]:
def recon_loss(raw_inputs, lens_output):
    """Return the summed squared error between ``raw_inputs`` and ``lens_output``.

    The result is a scalar tensor: squared element-wise differences are summed
    over all elements (``reduction='sum'``), not averaged.
    """
    return nn.functional.mse_loss(raw_inputs, lens_output, reduction='sum')

In [None]:
def lens_loss(raw_inputs, lens_output, lambda_term, ssl_loss = None, min_probs = None, final_outputs = None):
    """Adversarial loss for the lens plus a weighted reconstruction penalty.

    Two variants of the adversarial term are supported:

    * ``ssl_loss`` given: the lens maximises the downstream loss, i.e. the term
      is ``-ssl_loss``.
    * ``ssl_loss`` omitted: the term is the mean cross-entropy between
      ``final_outputs`` (classifier logits) and ``min_probs`` (target class
      indices, chosen by the caller as the least likely class).

    In both cases ``lambda_term * recon_loss(raw_inputs, lens_output)`` is added.

    Args:
        raw_inputs: images before the lens.
        lens_output: images after the lens.
        lambda_term: weight of the reconstruction penalty.
        ssl_loss: optional downstream loss tensor selecting the first variant.
        min_probs: target class indices for the second variant.
        final_outputs: classifier logits for the second variant.

    Returns:
        The total loss as a scalar tensor.

    Raises:
        ValueError: if ``ssl_loss`` is omitted and ``final_outputs`` or
            ``min_probs`` is missing.
    """
    reconstruction = lambda_term*recon_loss(raw_inputs, lens_output)
    # Compare against None rather than relying on truthiness: a loss tensor equal
    # to 0 is falsy and would otherwise be routed to the cross-entropy variant.
    if ssl_loss is not None:
        return -ssl_loss + reconstruction
    if final_outputs is None or min_probs is None:
        raise ValueError("lens_loss needs either ssl_loss or both final_outputs and min_probs")
    celoss = nn.CrossEntropyLoss(reduction='mean')
    adv_loss = celoss(final_outputs,min_probs)
    return adv_loss + reconstruction

In [None]:
class Resnet_FC4(nn.Module):
    """ResNet-50 whose final fully connected layer is replaced by a 4-output layer.

    The four outputs are used by this notebook as logits for the four rotation
    labels. ``models.resnet50()`` is called without a weights argument.
    """

    def __init__(self):
        super().__init__()
        # Feature extractor with a 2048 -> 4 classification head.
        backbone = models.resnet50()
        backbone.fc = nn.Linear(in_features=2048, out_features=4, bias=True)
        self.res = backbone

    def forward(self, x: Tensor) -> Tensor:
        """Return the 4 logits per input image."""
        return self.res(x)

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Bare expression: shown as the cell output.
device

In [None]:
# Batch-mean cross-entropy used for the rotation-classification (SSL) loss.
criterion = nn.CrossEntropyLoss(reduction='mean')
# Softmax over dim 1, applied to the classifier outputs before argmax/argmin.
sm = nn.Softmax(dim = 1)

In [None]:
# Release unused cached GPU memory held by PyTorch's allocator before training.
torch.cuda.empty_cache()

In [None]:
# Shuffled mini-batches of 512 items drawn from the test set.
batch_size = 512
testloader = DataLoader(testset, batch_size=batch_size, shuffle=True)

In [None]:
# Reconstruction-loss weights (lambda) to sweep: 1e-10, 6.25e-12 and 3.125e-12.
lambda_terms = [1e-10,1e-10/16,1e-10/32]

In [None]:
# Print each lambda value next to its index in the sweep list.
for i,k in enumerate(lambda_terms):
    print(i,k)

In [None]:
def _accuracy_percent(lens, extractor, loader, use_lens):
    """Return the floored percentage of items in ``loader`` classified correctly.

    Uses the notebook-level ``device`` and ``sm``. Gradients are disabled here;
    the caller is responsible for switching the models to eval mode.
    """
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device)
            outputs = extractor(lens(images)) if use_lens else extractor(images)
            predicted = torch.argmax(sm(outputs), dim = 1).cpu()
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return 100 * correct // total


def train_loop(lambda_term,ver,epochs):
    """Train a lens and a rotation classifier adversarially, then evaluate and save both.

    A fresh ``Unet_ResNet`` (lens) and ``Resnet_FC4`` (classifier) are trained on
    the notebook-level ``trainloader``. The classifier minimises the rotation
    cross-entropy (``criterion``); the lens minimises ``lens_loss``. Afterwards
    the accuracy on ``testloader`` and ``trainloader`` is printed and logged, and
    both models are saved under ``path/checkpoints/<ver>/``.

    Args:
        lambda_term: weight of the reconstruction term in ``lens_loss``; also
            used in the log and checkpoint file names.
        ver: experiment tag naming the output sub-directory.
        epochs: number of passes over ``trainloader``.
    """
    output_dir = Path(path/f'checkpoints/{ver}')
    output_dir.mkdir(parents=True, exist_ok=True)
    model1 = Unet_ResNet()
    model1.to(device)
    model2 = Resnet_FC4()
    model2.to(device)
    #Hyper parameter tuning
    lr = 0.01
    lens_usage = True
    if lens_usage:
        optim1 = optim.Adam(model1.parameters(), lr=lr)
    optim2 = optim.Adam(model2.parameters(), lr=lr)
    # The log lives next to the checkpoints; the context manager closes it even
    # when training raises part-way through.
    with open(output_dir/f"logging_{lambda_term}.txt", "a") as f:
        model1.train()
        model2.train()
        for epoch in tqdm(range(epochs)):
            ssl_losses = 0.0
            lens_losses = 0.0
            for i, (inputs, labels) in enumerate(trainloader):
                #Zero gradients out
                if lens_usage:
                    optim1.zero_grad()
                optim2.zero_grad()
                inputs, labels = inputs.to(device), labels.to(device)

                if lens_usage:
                    lens_output = model1(inputs)
                    # Cut the graph between lens and classifier so each network
                    # can be given the gradient of its own loss.
                    lens_out_detach = lens_output.detach()
                    lens_out_detach.requires_grad_(True)
                    outputs = model2(lens_out_detach)
                    #For type 2 of Adversarial loss
                    min_probs = torch.argmin(sm(outputs),dim=1)
                    ssl_loss = criterion(outputs, labels)
                    #Uncomment if run full adversarial loss
                    # l_loss = lens_loss(inputs, lens_out_detach, lambda_term = lambda_term, ssl_loss = ssl_loss)
                    l_loss = lens_loss(inputs, lens_out_detach, lambda_term = lambda_term, min_probs = min_probs, final_outputs = outputs)
                    # retain_graph: ssl_loss.backward() below reuses the classifier graph.
                    l_loss.backward(retain_graph=True)
                    lens_output.backward(lens_out_detach.grad) #Let the grad of l_loss go thru
                    optim2.zero_grad() #Clear out l_loss grad from model2
                    ssl_loss.backward()
                else:
                    outputs = model2(inputs)
                    ssl_loss = criterion(outputs, labels)
                    ssl_loss.backward()
                #Update step
                if lens_usage:
                    optim1.step()
                    lens_losses += l_loss.item()
                optim2.step()
                ssl_losses += ssl_loss.item()
                if i>0 and i % 10 == 0 and epoch % 2 == 0:
                    # Batches 0..i have been accumulated, so the mean is over i + 1 batches.
                    batches_seen = i + 1
                    message = f'[{epoch}, batch {i}] ssl_loss: {ssl_losses / batches_seen:.3f} lens_loss: {lens_losses / batches_seen:.3f}'
                    print(message)
                    f.write(message + '\n')

        #Evaluation
        model2.eval()
        model1.eval()#SETTING EVAL MODE
        test_acc = _accuracy_percent(model1, model2, testloader, lens_usage)
        print(f'Accuracy on test images: {test_acc} %')
        f.write(f'Accuracy on test images: {test_acc} %\n')

        #Evaluation on trainloader
        train_acc = _accuracy_percent(model1, model2, trainloader, lens_usage)
        print(f'Accuracy on train images: {train_acc} %')
        f.write(f'Accuracy on train images: {train_acc} %\n')

        # Whole model objects are pickled (not just state_dicts).
        torch.save(model1, output_dir/f'lens_{lambda_term}.pth')
        torch.save(model2, output_dir/f'extractor_{lambda_term}.pth')
        print('========')
        # Trailing newline: the log is opened in append mode, so the next run
        # must start on its own line.
        f.write('========\n')

In [None]:
#Hyper param tuning
ver = '006'
epochs = 30
for term in lambda_terms:
    print('Training for', term)
    train_loop(term,ver,epochs)

In [None]:
def eval_loop(lens_usage, model2, testloader, device, model1 = None):
    """Print the classification accuracy of ``model2`` over ``testloader``.

    Args:
        lens_usage: when truthy, inputs are passed through ``model1`` before
            ``model2``; otherwise ``model2`` sees the raw inputs.
        model2: the classifier; switched to eval mode.
        testloader: iterable yielding ``(images, labels)`` batches.
        device: device the images are moved to before the forward pass.
        model1: the lens network. Required when ``lens_usage`` is truthy; may be
            left as ``None`` otherwise.

    Raises:
        ValueError: if ``lens_usage`` is truthy but ``model1`` is ``None``.

    The accuracy is printed as a floored integer percentage; nothing is returned.
    """
    if lens_usage and model1 is None:
        raise ValueError("lens_usage=True requires model1 (the lens network)")
    model2.eval()
    # model1 is optional: only switch it to eval mode when one was supplied.
    if model1 is not None:
        model1.eval()  # SETTING EVAL MODE
    sm = nn.Softmax(dim=1)
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in testloader:
            images = images.to(device)
            outputs = model2(model1(images)) if lens_usage else model2(images)
            predicted = torch.argmax(sm(outputs), dim=1).cpu()
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print(f'Accuracy on test images: {100 * correct // total} %')

In [None]:
# Suffix of the lens_*.pth / extractor_*.pth checkpoint files to load.
model_name = '005_30_2e-10'
# Whether the lens model is loaded and applied in front of the classifier.
lens_usage = True

In [None]:
# Load the pickled whole-model checkpoints and evaluate them on the test loader.
# map_location=device lets checkpoints saved on a GPU load on a CPU-only machine.
# NOTE(review): these files are full pickles, so only load checkpoints you trust;
# PyTorch releases that default torch.load to weights_only=True will reject them
# unless weights_only=False is passed explicitly.
if lens_usage:
    model1 = torch.load(f'lens_{model_name}.pth', map_location=device)
    model1.to(device)
else:
    model1 = None
model2 = torch.load(f'extractor_{model_name}.pth', map_location=device)
model2.to(device)
eval_loop(lens_usage, model2, testloader, device, model1 = model1)

### Old training code

In [None]:
# Freshly constructed lens and rotation classifier, moved to `device`.
# This rebinds model1/model2, replacing any models loaded in earlier cells.
model1 = Unet_ResNet()
model1.to(device)
model2 = Resnet_FC4()
model2.to(device)

In [None]:
#Hyper parameter tuning
lr = 0.01
lens_usage = True

In [None]:
# One Adam optimizer per network; the lens optimizer exists only when the lens is used.
# NOTE(review): betas=(0.1, 0.001) are far below Adam's defaults of (0.9, 0.999) —
# confirm this is intentional.
if lens_usage:
    optim1 = optim.Adam(model1.parameters(), lr=lr, betas=(0.1, 0.001), eps=1e-07)
optim2 = optim.Adam(model2.parameters(), lr=lr, betas=(0.1, 0.001), eps=1e-07)

In [None]:
# Number of passes over trainloader for the legacy training cell.
epochs = 5

In [None]:
# Legacy notebook-level version of the adversarial training loop.
# `lambda_term` is bound nowhere else at notebook scope (it is only a parameter of
# train_loop), so bind it here to keep the cell runnable on a fresh kernel.
# NOTE(review): the first swept value is used as a stand-in — confirm which
# reconstruction weight this loop is meant to use.
lambda_term = lambda_terms[0]
model1.train()
model2.train()
for epoch in range(epochs):
    ssl_losses = 0.0
    lens_losses = 0.0
    for i, (inputs, labels) in enumerate(trainloader):
        #Zero gradients out
        if lens_usage:
            optim1.zero_grad()
        optim2.zero_grad()
        inputs, labels = inputs.to(device), labels.to(device)

        if lens_usage:
            lens_output = model1(inputs)
            # Cut the graph between lens and classifier so each network can be
            # given the gradient of its own loss.
            lens_out_detach = lens_output.detach()
            lens_out_detach.requires_grad_(True)
            outputs = model2(lens_out_detach)
            #For type 2 of Adversarial loss
            min_probs = torch.argmin(sm(outputs),dim=1)
            ssl_loss = criterion(outputs, labels)
            #Uncomment if run full adversarial loss
#             l_loss = lens_loss(inputs, lens_out_detach, lambda_term = lambda_term, ssl_loss = ssl_loss)
            l_loss = lens_loss(inputs, lens_out_detach, lambda_term = lambda_term, min_probs = min_probs, final_outputs = outputs)
            # retain_graph: ssl_loss.backward() below reuses the classifier graph.
            l_loss.backward(retain_graph=True)
            lens_output.backward(lens_out_detach.grad) #Let the grad of l_loss go thru
            optim2.zero_grad() #Clear out l_loss grad from model2
            ssl_loss.backward()
        else:
            outputs = model2(inputs)
            ssl_loss = criterion(outputs, labels)
            ssl_loss.backward()
        #Update step
        if lens_usage:
            optim1.step()
            lens_losses += l_loss.item()
        optim2.step()
        ssl_losses += ssl_loss.item()
        #print(model2.res.conv1.weight.grad)
        #print(model1.conv_input[0].weight.grad)
        if i>0 and i % 50 == 0:
            # Batches 0..i have been accumulated, so the mean is over i + 1 batches.
            print(f'[{epoch}, batch {i}] ssl_loss: {ssl_losses / (i + 1):.3f} lens_loss: {lens_losses / (i + 1):.3f}')

## Evaluation of pretext task

In [None]:
# Rebuild the test loader with smaller shuffled batches of 64 for evaluation.
batch_size = 64
testloader = DataLoader(testset, batch_size=batch_size, shuffle=True)

In [None]:
#Evaluation
correct = 0
total = 0
model2.eval()
model1.eval()#SETTING EVAL MODE
with torch.no_grad():
    for data in trainloader:
        images, labels = data
        if lens_usage:
            outputs = model2(model1(images.to(device)))
        else:
            outputs = model2(images.to(device))
        predicted = torch.argmax(sm(outputs), dim = 1).cpu()
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
print(f'Accuracy on test images: {100 * correct // total} %')

## Visualization inspection

In [None]:
# Small shuffled batches of 8 test items for visual inspection.
batch_size = 8
testloader = DataLoader(testset, batch_size=batch_size, shuffle=True)

In [None]:
# Take the first batch from a fresh (shuffled) pass over the test loader.
img, label = next(iter(testloader))

In [None]:
# Run the inspected batch through the lens (when enabled) and the classifier
# without tracking gradients, then print the mean cross-entropy against the labels.
model2.eval()
model1.eval()
with torch.no_grad():
    if lens_usage:
        # Keep the lensed images: later cells visualise them.
        img_lensed = model1(img.to(device))
        pred = model2(img_lensed)
    else:
        pred = model2(img.to(device))
    print(criterion(pred, label.to(device)))

In [None]:
# Rotation labels of the inspected batch (shown as the cell output).
label

In [None]:
# Predicted rotation class per image: argmax over the softmax of the classifier outputs.
torch.argmax(sm(pred), dim = 1)

In [None]:
# Position within the inspected batch of the image to visualise.
seeing_index = 1

In [None]:
# Show the selected image as loaded from the test set (before the lens).
imshow(img[seeing_index].squeeze())

In [None]:
# Show the same image after the lens; moved to the CPU because imshow calls .numpy().
if lens_usage:
    imshow(img_lensed[seeing_index].squeeze().cpu())

In [None]:
# Show what the lens changed (lensed minus original). imshow maps values through
# img/2 + 0.5, so a zero difference is plotted as 0.5.
if lens_usage:
    imshow(img_lensed[seeing_index].squeeze().cpu() - img[seeing_index].squeeze())

In [None]:
# Mean squared error between the selected original image and its lensed version.
# Requires img_lensed, which is only set when lens_usage is true.
k = nn.MSELoss(reduction = 'mean')
k(img[seeing_index],img_lensed[seeing_index].squeeze().cpu())

## Saving model for transfer learning

In [None]:
# Version tag used in the file names of the models saved below.
ver = '004' 

In [None]:
# List the direct child modules of the classifier (shown as the cell output).
list(model2.children())

In [None]:
# Save both whole model objects (not just their state_dicts) under `path`,
# tagged with `ver`.
torch.save(model1, path/f'lens_{ver}.pth')
torch.save(model2, path/f'extractor_{ver}.pth')