In [None]:
from torchvision import datasets
from torchvision.transforms import transforms
from torch.utils.data import DataLoader
# from torch.optim.lr_scheduler import StepLR, ExponentialLR
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch
import numpy as np
# Data Visualization
import plotly.express as px
# from tqdm.auto import tqdm
from tqdm import tqdm
import matplotlib.pyplot as plt
# Experimental
from torchvision import models
from contrastive_learner import ContrastiveLearner

In [None]:
class AttentionNetwork(nn.Module):
    """Small CNN classifier with a single-map spatial attention stage.

    Three conv layers (the first two each followed by 2x2 max-pooling) feed two
    fully connected layers. After the first conv block a 1x1 conv produces one
    logit per pixel; a softmax over *all* spatial positions turns those logits
    into an attention map that sums to 1 per image, and the feature maps are
    re-weighted by it before pooling.

    Args:
        input_size: Height/width of the square input images. Defaults to 100,
            the size this network was originally hard-coded for.
        num_classes: Number of output logits. Defaults to 2.

    ``forward`` returns ``(logits, attention_weights)`` with shapes
    ``(B, num_classes)`` and ``(B, 1, H, W)``.
    """

    def __init__(self, input_size=100, num_classes=2):
        super().__init__()
        # Two 2x2 max-pools (each flooring) shrink the spatial side before fc1.
        pooled = input_size // 2 // 2
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 100, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(100 * pooled * pooled, 512)
        self.fc2 = nn.Linear(512, num_classes)
        self.batch_norm_1 = nn.BatchNorm2d(32)
        self.batch_norm_2 = nn.BatchNorm2d(64)
        self.dropout = nn.Dropout(p=0.5)

        # Attention applied after the first convolutional block: a 1x1 conv
        # giving one logit per pixel. The spatial softmax is done in forward()
        # over the flattened H*W positions; the conv stays inside a Sequential
        # so existing state_dict keys ('attention.0.*') still load.
        self.attention = nn.Sequential(
            nn.Conv2d(32, 1, kernel_size=1),
        )

    def forward(self, x):
        """Classify ``x`` and also return the attention map that was applied.

        Args:
            x: Image batch, assumed (B, 3, input_size, input_size).

        Returns:
            ``(logits, attention_weights)``.
        """
        x = torch.relu(self.conv1(x))
        x = self.batch_norm_1(x)
        x = self.dropout(x)

        # Spatial attention: softmax jointly over every (row, col) position so
        # the map sums to 1 per image. A softmax over dim=-1 alone normalises
        # each row independently, which forces every row to carry the same
        # total weight and prevents attending to particular rows.
        attention_logits = self.attention(x)  # (B, 1, H, W)
        attention_weights = torch.softmax(
            attention_logits.flatten(start_dim=2), dim=-1
        ).reshape_as(attention_logits)
        x = x * attention_weights

        x = F.max_pool2d(x, kernel_size=2, stride=2)
        x = torch.relu(self.conv2(x))
        x = self.batch_norm_2(x)
        x = self.dropout(x)
        x = F.max_pool2d(x, kernel_size=2, stride=2)
        x = torch.relu(self.conv3(x))

        x = torch.flatten(x, start_dim=1)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x, attention_weights

In [None]:
# Seed torch's global RNG so weight initialisation (and later draws from the
# global generator, e.g. shuffling and dropout) is reproducible on Run All.
SEED = 0
torch.manual_seed(SEED)
model = AttentionNetwork()

In [None]:
# Downscale factor applied to the source images.
s_1 = 0.2
# NOTE(review): assumes the source images are 500 px on a side — confirm.
ORIGINAL_IMAGE_SIZE = 500
# int(0.2 * 500) == 100, matching the 100 x 100 input AttentionNetwork's fc1 was sized for.
IMAGE_SIZE = int(s_1 * ORIGINAL_IMAGE_SIZE)
COLON_DATA_DIR = '/content/drive/MyDrive/Graduate/Courses/Winter 2024/EECS 6322/Course Project/CustomData/Colon'

colon_dataset = datasets.ImageFolder(
    COLON_DATA_DIR,
    transform=transforms.Compose([
        # Downscale by s_1 to an exact IMAGE_SIZE x IMAGE_SIZE square. An int
        # argument to Resize only fixes the *shorter* edge, so non-square
        # images would neither stack into a batch nor match fc1's input size.
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        # transforms.ColorJitter() with default arguments (all factors 0) is an
        # identity transform, so it is omitted; pass explicit factors if colour
        # augmentation is actually wanted.
        transforms.ToTensor(),
    ]),
)

In [None]:
# Colon Dataset
loaders = {
    'train' : DataLoader(
    colon_dataset,
    batch_size=5,
    shuffle=True,
    num_workers=1
    )
}

In [None]:
# Zoom-In with Contrastive Learning: self-supervised training of `model`
# through ContrastiveLearner; the ImageFolder class labels are not used.
# NOTE(review): `learner` was never constructed anywhere in this notebook, so
# this cell raised NameError on a fresh kernel. The arguments below are a
# reconstruction (wrap `model`, 100 px inputs, embed from fc1, momentum key
# encoder because update_moving_average() is called) — TODO confirm.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
learner = ContrastiveLearner(
    model,
    image_size=100,
    hidden_layer='fc1',
    use_momentum=True,
).to(device)

# Optimise the learner's parameters rather than only model.parameters():
# the learner wraps `model` and may own extra trainable modules (e.g. a
# projection head) that model.parameters() would miss.
optimizer = optim.Adam(
    learner.parameters(),
    betas=(0.9, 0.999),
    lr=0.001,
    weight_decay=1e-5,
)

NUM_EPOCHS = 100
LOG_EVERY = 10  # report the mean epoch loss every LOG_EVERY epochs

for epoch in range(NUM_EPOCHS):
    learner.train()
    epoch_loss = 0.0
    num_batches = 0
    for images, _ in tqdm(loaders['train'], desc=f'Epoch {epoch}'):
        loss = learner(images.to(device))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Refresh the momentum (key) encoder after every optimiser step.
        learner.update_moving_average()
        epoch_loss += loss.item()
        num_batches += 1
    if epoch % LOG_EVERY == 0:
        # One line per logged epoch instead of one per batch.
        print(f'Epoch {epoch}: mean loss {epoch_loss / max(num_batches, 1):.4f}')
