In [1]:
from Twins.barlow import *
from Twins.transform_utils import *
import torch
import torchvision
from torchvision import models
import torchvision.transforms as transforms
import torchvision.datasets as dsets
import Twins
from torchvision.transforms import ToTensor
from tqdm import tqdm


In [2]:
encoder = models.resnet18()

In [3]:
## apply the transformation
image_size = (32,32)
mean = (0.4914, 0.4822, 0.4465) 
std = (0.247, 0.243, 0.261)
transform = transforms.Compose([
                transforms.RandomResizedCrop(image_size, interpolation=Image.BICUBIC),
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomApply(
                    [transforms.ColorJitter(brightness=0.4, contrast=0.4,
                                            saturation=0.2, hue=0.1)],
                    p=0.8
                ),
                transforms.RandomGrayscale(p=0.2),
                GaussianBlur(p=1.0),
                Solarization(p=0.0),
                transforms.ToTensor(),
                transforms.Normalize(mean, std)
            ])

  transforms.RandomResizedCrop(image_size, interpolation=Image.BICUBIC),


In [4]:
dataset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=False, transform=Transform(transform, transform))

Files already downloaded and verified


In [5]:
loader = torch.utils.data.DataLoader(dataset,
                                        batch_size=32,
                                        shuffle=True) ## creating the iterable

In [21]:
learner = BarlowTwins(encoder, 'avgpool', [512,1024, 1024, 512],
                      3.9e-3, 1) ## lightining module, backbone of the model

In [22]:
print(encoder) ## resnet--18 

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [23]:
print(learner)## backbone

BarlowTwins(
  (backbone): NetWrapper(
    (net): ResNet(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (1): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): Batch

In [24]:
from torchsummary import summary
summary(learner, input_size = (32, 3,32, 32))

Layer (type:depth-idx)                   Param #
├─NetWrapper: 1-1                        --
|    └─ResNet: 2-1                       --
|    |    └─Conv2d: 3-1                  9,408
|    |    └─BatchNorm2d: 3-2             128
|    |    └─ReLU: 3-3                    --
|    |    └─MaxPool2d: 3-4               --
|    |    └─Sequential: 3-5              147,968
|    |    └─Sequential: 3-6              525,568
|    |    └─Sequential: 3-7              2,099,712
|    |    └─Sequential: 3-8              8,393,728
|    |    └─AdaptiveAvgPool2d: 3-9       --
|    |    └─Linear: 3-10                 513,000
├─Sequential: 1-2                        --
|    └─Linear: 2-2                       524,288
|    └─BatchNorm1d: 2-3                  2,048
|    └─ReLU: 2-4                         --
|    └─Linear: 2-5                       1,048,576
|    └─BatchNorm1d: 2-6                  2,048
|    └─ReLU: 2-7                         --
|    └─Linear: 2-8                       524,288
├─BatchNorm1d: 

Layer (type:depth-idx)                   Param #
├─NetWrapper: 1-1                        --
|    └─ResNet: 2-1                       --
|    |    └─Conv2d: 3-1                  9,408
|    |    └─BatchNorm2d: 3-2             128
|    |    └─ReLU: 3-3                    --
|    |    └─MaxPool2d: 3-4               --
|    |    └─Sequential: 3-5              147,968
|    |    └─Sequential: 3-6              525,568
|    |    └─Sequential: 3-7              2,099,712
|    |    └─Sequential: 3-8              8,393,728
|    |    └─AdaptiveAvgPool2d: 3-9       --
|    |    └─Linear: 3-10                 513,000
├─Sequential: 1-2                        --
|    └─Linear: 2-2                       524,288
|    └─BatchNorm1d: 2-3                  2,048
|    └─ReLU: 2-4                         --
|    └─Linear: 2-5                       1,048,576
|    └─BatchNorm1d: 2-6                  2,048
|    └─ReLU: 2-7                         --
|    └─Linear: 2-8                       524,288
├─BatchNorm1d: 

In [25]:
optimizer = torch.optim.Adam(learner.parameters(), lr=0.001) ## adam is used here for optimiser 

In [26]:
epochs = 1  ## running a single loop
for epoch in tqdm(range(epochs)):
    for batch_idx, ((x1,x2), _) in enumerate(loader):
        loss, z1 = learner(x1, x2)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

100%|██████████| 1/1 [16:48<00:00, 1008.37s/it]


In [27]:
## after training the model we will load the model to prevent unnecessary training 
model_path = 'BT_trained.pth'
torch.save(learner.state_dict(), model_path)


In [28]:
## loading the model 
loaded_BT = BarlowTwins(encoder, 'avgpool', [512,1024, 1024, 512],
                      3.9e-3, 1)
loaded_BT.load_state_dict(torch.load(model_path))

<All keys matched successfully>

## end to end ml comms part

we use an autoencoder type architecture to mimic the Tx and the Rx

In [33]:
##end to end ml
## use resent-9 as per planned before
GOAI_model = models.resnet18(pretrained=False).to('cuda')
GOAI_model.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
num_features = GOAI_model.fc.in_features
GOAI_model.fc = nn.Linear(num_features, 10)

ai_optimizer = torch.optim.Adam(GOAI_model.parameters(), lr=0.001)
goai_loss_fn = nn.CrossEntropyLoss()

In [30]:
epochs =  1
for epoch in tqdm(range(epochs)):
    for batch_idx, ((x1,x2), target) in enumerate(loader):
        GOAI_model.train()
        c, z1 = learner.forward(x1,x2)
        print(z1.shape)
        break

100%|██████████| 1/1 [00:00<00:00, 20.30it/s]

torch.Size([32, 512])





In [31]:
loader = torch.utils.data.DataLoader(dataset,
                                        batch_size=1,
                                        shuffle=True) ## creating the iterable

In [34]:
##training the goai model
## one epoch only
epochs =  1
for epoch in tqdm(range(epochs)):
    for batch_idx, ((x1,x2), target) in enumerate(loader):
        GOAI_model.to('cuda')
        GOAI_model.train()
        c, z1 = learner.forward(x1,x2)
        input_to_goai = z1.unsqueeze(0)
        input_to_goai = input_to_goai.unsqueeze(0)
        
        input_to_goai =input_to_goai.to('cuda')

        y_pred = GOAI_model(input_to_goai)

        
        target = target.to('cuda')
        loss = goai_loss_fn(y_pred, target)

        ai_optimizer.zero_grad()

        loss.backward()

        ai_optimizer.step()
## this is the training loop of the Tx and Rx task one epoch only 

100%|██████████| 1/1 [20:59<00:00, 1259.34s/it]
