In [1]:
import numpy as np
import pandas as pd
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # Select the GPU index
import scipy.io
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
import math
from PIL import Image
from collections import OrderedDict
import torchvision.transforms as T
import matplotlib.pyplot as plt
import time
from torch.optim.lr_scheduler import _LRScheduler
import warnings
import spacy
from scipy.io import savemat
from scipy import stats
import dill as pickle
import thop
from torch_challenge_dataset import DeepVerseChallengeLoaderTaskThree
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=UserWarning)

In [2]:
# Parameters
# Flags selecting which side-information modalities the decoder consumes.
onoffdict = {'GPS': True, 'CAMERAS': True, 'RADAR': True}
# Divisor applied to the 8192-wide flattened input to get the encoder output width.
reduction = 4
# Side length of the square output map produced by the decoders.
num_H = 64
# Directory holding the task-3 checkpoint for this modality combination.
weight_path = (
    f'models/CSINettask3/cr{reduction}/'
    f'gps{onoffdict["GPS"]}_cam{onoffdict["CAMERAS"]}_rad{onoffdict["RADAR"]}/'
)

In [3]:
# Create the checkpoint directory if needed. exist_ok=True avoids the
# check-then-create race of a separate os.path.exists() test.
os.makedirs(weight_path, exist_ok=True)

In [4]:
# Number of samples per DataLoader batch.
batch_size = 200


In [5]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Utils and Models

In [6]:
def CSI_abs_reshape(y, csi_std=2.8117975e-06, target_std=1.0):
    """Return the element-wise magnitude of ``y`` rescaled by ``target_std / csi_std``.

    Args:
        y: tensor whose absolute value (modulus for complex input) is taken.
        csi_std: divisor applied to the magnitudes.
        target_std: multiplier applied after the division.

    Returns:
        Tensor of the same shape as ``y`` holding the rescaled magnitudes.
    """
    magnitude = y.abs()
    return (magnitude / csi_std) * target_std

In [7]:
def CSI_reshape( y, csi_std=2.5e-06, target_std=1):
    """Split a complex tensor into real and imaginary parts stacked along dim 1.

    Args:
        y: complex tensor with at least two dimensions.
        csi_std: divisor applied to the stacked values.
        target_std: multiplier applied after the division.

    Returns:
        Real tensor whose dim 1 is twice as long as that of ``y`` (real part
        first, imaginary part second), rescaled by ``target_std / csi_std``.
    """
    stacked = torch.cat((y.real, y.imag), dim=1)
    # scaling
    return (stacked / csi_std) * target_std

In [8]:
def cal_model_parameters(model):
    """Return the total number of scalar entries across ``model.parameters()``."""
    return sum(int(weights.numel()) for weights in model.parameters())

In [9]:
def normalize_image(image):
    """Scale pixel values by 1/255 and apply ImageNet mean/std normalisation.

    Args:
        image: image tensor accepted by ``torchvision.transforms.Normalize``
            with three channels.

    Returns:
        A new float32 tensor. The input is never modified.
    """
    # Out-of-place division: ``image.float()`` returns the input tensor itself
    # when it is already float32, so an in-place ``/=`` would silently rescale
    # the caller's data.
    scaled = image.float() / 255.0
    # ImageNet mean values / ImageNet standard deviation values
    imagenet_norm = T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    return imagenet_norm(scaled)

In [10]:
def left_coordinates_batch(x_cor, y_cor):
    """Map batched (x, y) coordinates to pixel coordinates for the left image.

    ``y_cor`` selects one of five bands; each band has its own linear map from
    ``x_cor`` to the pixel row and from the pixel row to the pixel column.
    The numeric constants are kept exactly as given (presumably hand-calibrated
    for the camera view; confirm before changing).

    Args:
        x_cor: 1-D tensor of x coordinates.
        y_cor: 1-D tensor of y coordinates, same shape as ``x_cor``.

    Returns:
        ``(x_pix, y_pix)`` tensors. Entries matching no band (e.g. NaN
        ``y_cor``) stay zero.
    """
    y_pix = torch.zeros_like(x_cor)
    x_pix = torch.zeros_like(y_cor)

    # Fraction of the way from x=80 to x=125.
    progress = (x_cor - 80) / (125 - 80)

    # (band mask, pixel row reached at x=125, column offset, column slope)
    bands = (
        (y_cor < -4, 250, 30, 1.35),
        ((y_cor >= -4) & (y_cor < -1), 250, 45, 0.85),
        ((y_cor >= -1) & (y_cor < 1), 250, 55, 0.70),
        ((y_cor >= 1) & (y_cor < 4), 210, 65, 0.60),
        (y_cor >= 4, 190, 65, 0.5),
    )
    for mask, row_at_end, col_offset, col_slope in bands:
        y_pix[mask] = 100 + (row_at_end - 100) * progress[mask]
        x_pix[mask] = (y_pix[mask] - col_offset) / col_slope
    return x_pix, y_pix

In [11]:
def center_coordinates_batch(x_cor, y_cor):
    """Map batched (x, y) coordinates to pixel coordinates for the centre image.

    Args:
        x_cor: 1-D tensor of x coordinates.
        y_cor: 1-D tensor of y coordinates, same shape as ``x_cor``.

    Returns:
        ``(x_pix, y_pix)`` tensors. ``x_pix`` uses one of two linear maps
        depending on the sign of ``y_cor``; ``y_pix`` is linear in ``y_cor``
        (175 at y=-7, 100 at y=7).
    """
    negative_y = y_cor < 0
    x_when_negative = 256 * ((x_cor - 119) / (139 - 119))
    # NOTE(review): the offset (112) and the range start (113) differ, unlike
    # the 119/119 pair above. Possibly a typo, but left untouched because
    # trained weights may depend on this exact mapping. Confirm.
    x_otherwise = 256 * ((x_cor - 112) / (146 - 113))
    # Cast back so the result keeps the dtype of ``x_cor``.
    x_pix = torch.where(negative_y, x_when_negative, x_otherwise).to(x_cor.dtype)

    y_pix = 175 + (100 - 175) * ((y_cor - (-7)) / ((7) - (-7)))
    return x_pix, y_pix

In [12]:
def right_coordinates_batch(x_cor, y_cor):
    """Map batched (x, y) coordinates to pixel coordinates for the right image.

    ``y_cor`` selects one of five bands; each band has its own linear map from
    ``x_cor`` to the pixel row and from the pixel row to the pixel column.
    The numeric constants are kept exactly as given (presumably hand-calibrated
    for the camera view; confirm before changing).

    Args:
        x_cor: 1-D tensor of x coordinates.
        y_cor: 1-D tensor of y coordinates, same shape as ``x_cor``.

    Returns:
        ``(x_pix, y_pix)`` tensors. Entries matching no band (e.g. NaN
        ``y_cor``) stay zero.
    """
    y_pix = torch.zeros_like(x_cor)
    x_pix = torch.zeros_like(y_cor)

    # Fraction of the way from x=125 to x=200.
    progress = (x_cor - 125) / (200 - 125)

    # (band mask, pixel row at x=125, column intercept, column slope)
    bands = (
        (y_cor < -4, 250, 370, 1.25),
        ((y_cor >= -4) & (y_cor < -1), 250, 285, 0.87),
        ((y_cor >= -1) & (y_cor < 1), 250, 250, 0.73),
        ((y_cor >= 1) & (y_cor < 4), 210, 210, 0.55),
        (y_cor >= 4, 190, 190, 0.45),
    )
    for mask, row_at_start, col_intercept, col_slope in bands:
        y_pix[mask] = row_at_start + (100 - row_at_start) * progress[mask]
        x_pix[mask] = -(y_pix[mask] - col_intercept) / col_slope
    return x_pix, y_pix

In [13]:
def center_image_batch(images, center_x, center_y, output_size, bounded=False):
    """Crop a window of ``output_size`` around a per-image centre point.

    Args:
        images: batched image tensor; one crop is taken per entry along dim 0.
        center_x: per-image column of the crop centre (integer tensor).
        center_y: per-image row of the crop centre (integer tensor).
        output_size: ``(height, width)`` of the crop.
        bounded: ``'left'`` clamps the top/left corner to be non-negative;
            ``'right'`` clamps the right edge to at most 250 and the top to be
            non-negative; any other value applies no clamping.

    Returns:
        Tensor stacking the crops, each passed through a resize to
        ``output_size``.
    """
    crop_h, crop_w = output_size[0], output_size[1]

    if bounded == 'left':
        top = torch.clamp(center_y - crop_h // 2, 0, None)
        left = torch.clamp(center_x - crop_w // 2, 0, None)
    elif bounded == 'right':
        bottom = center_y + crop_h // 2
        right_edge = torch.clamp(center_x + crop_w // 2, None, 250)
        top = torch.clamp(bottom - crop_h, 0, None)
        left = right_edge - crop_w
    else:
        top = center_y - crop_h // 2
        left = center_x - crop_w // 2

    to_output_size = transforms.Resize(output_size)
    patches = []
    for idx in range(images.size(0)):
        patch = TF.crop(images[idx], int(top[idx].item()), int(left[idx].item()), crop_h, crop_w)
        patches.append(to_output_size(patch))
    return torch.stack(patches)

In [14]:
def process_imgs(gps, img_1, img_2, img_3, crop_size = (150,150)):
    """Crop the three camera batches around the pixel position derived from ``gps``.

    Args:
        gps: tensor whose column 0 is the x coordinate and column 1 the y coordinate.
        img_1: image batch cropped with the left-camera mapping (``'left'`` bounds).
        img_2: image batch cropped with the centre-camera mapping (no bounds).
        img_3: image batch cropped with the right-camera mapping (``'right'`` bounds).
        crop_size: ``(height, width)`` of each crop.

    Returns:
        Tuple ``(img_1, img_2, img_3)`` of cropped batches.
    """
    x_cor, y_cor = gps[:, 0], gps[:, 1]

    def crop_with(images, to_pixels, *bound_mode):
        # Pixel coordinates are truncated to integers before cropping.
        x_pix, y_pix = to_pixels(x_cor, y_cor)
        return center_image_batch(images, x_pix.to(torch.int), y_pix.to(torch.int), crop_size, *bound_mode)

    left_crop = crop_with(img_1, left_coordinates_batch, 'left')
    center_crop = crop_with(img_2, center_coordinates_batch)
    right_crop = crop_with(img_3, right_coordinates_batch, 'right')
    return left_crop, center_crop, right_crop

GPS Data Processing Layer

In [15]:
class gpsdata(nn.Module):
    """Embed 2-D coordinates into a 16-dimensional feature vector.

    Each coordinate column is min-max scaled over the current batch, then
    passed through ``Linear(2, 16)`` and ReLU.

    NOTE(review): the scaling uses the min/max of the batch it is given, so
    the features of a sample depend on which other samples share its batch.
    """
    def __init__(self):
        super().__init__()
        self.gps_fc = nn.Linear(2, 16)
        self.gps_relu = nn.ReLU()

    @staticmethod
    def _min_max_scale(values):
        """Scale ``values`` to [0, 1] over the batch; a constant column maps to 0."""
        lowest = torch.min(values)
        span = torch.max(values) - lowest
        # A single-sample or constant batch has span == 0, which would turn
        # 0/0 into NaN and poison every downstream feature. Dividing by 1
        # there yields 0; non-degenerate batches are unaffected.
        safe_span = torch.where(span > 0, span, torch.ones_like(span))
        return (values - lowest) / safe_span

    def forward(self, gps):
        """Return ``(batch, 16)`` features for ``gps`` of shape ``(batch, >=2)``."""
        gps = gps.to(torch.float32)

        x_normd = self._min_max_scale(gps[:, 0])
        y_normd = self._min_max_scale(gps[:, 1])
        gps_normd = torch.stack([x_normd, y_normd], dim=1)

        return self.gps_relu(self.gps_fc(gps_normd))

Radar Data Processing Layer

In [16]:
class radardata(nn.Module):
    """Reduce a 512x128 single-channel map to a 16-dimensional feature vector.

    Two stride-2 convolutions, each followed by dropout and 2x2 average
    pooling, bring the map down to 32x8 = 256 values, which a linear layer
    maps to 16 features.
    """

    def __init__(self):
        super().__init__()
        # Attribute names and creation order are kept unchanged so that
        # existing checkpoints and seeded initialisation still line up.
        self.dropout = nn.Dropout(p=0.5)
        self.conv1 = nn.Conv2d(1, 6, kernel_size=3, stride=2, padding=1)
        self.pool1 = nn.AvgPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(6, 1, kernel_size=3, stride=2, padding=1)
        self.pool2 = nn.AvgPool2d(kernel_size=2, stride=2)
        self.lr1 = nn.LeakyReLU(negative_slope=0.3, inplace=True)
        self.encoder_fc = nn.Linear(256, 16)

    def forward(self, x):
        """Return ``(batch, 16)`` features; ``x`` must be viewable as ``(-1, 1, 512, 128)``."""
        radar_map = x.view(-1, 1, 512, 128)
        num_samples = radar_map.size(0)
        # Fixed min-max scaling constants (presumably dataset statistics; confirm).
        radar_map = (radar_map - 5.1838e-06) / (28.0494 - 5.1838e-06)

        stage1 = self.pool1(self.dropout(self.conv1(radar_map)))
        stage2 = self.pool2(self.dropout(self.conv2(stage1)))
        flat = stage2.view(num_samples, -1)
        features = self.dropout(self.encoder_fc(flat))
        return self.lr1(features)

Camera Data Processing Layers

In [17]:
class cameradata(nn.Module):
    """Reduce a 3-channel image batch to a 16-dimensional feature vector.

    Two stride-2 convolutions with dropout and 2x2 average pooling must leave
    81 values per sample (the case for 150x150 inputs), which a linear layer
    maps to 16 features.
    """
    def __init__(self):
        super().__init__()
        # Attribute names and creation order are kept unchanged so that
        # existing checkpoints and seeded initialisation still line up.
        self.conv1 = nn.Conv2d(3, 6, kernel_size=3, stride=2, padding=1)
        self.pool1 = nn.AvgPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(6, 1, kernel_size=3, stride=2, padding=1)
        self.pool2 = nn.AvgPool2d(kernel_size=2, stride=2)
        self.lr1 = nn.LeakyReLU(negative_slope=0.3, inplace=True)
        self.encoder = nn.Linear(81, 16)
        self.dropout = nn.Dropout(0.5)

    def forward(self, cam):
        """Return ``(batch, 16)`` features for the image batch ``cam``."""
        frames = normalize_image(cam).to(torch.float32)
        stage1 = self.pool1(self.dropout(self.conv1(frames)))
        stage2 = self.pool2(self.dropout(self.conv2(stage1)))
        flat = self.lr1(stage2).view(-1, 81)
        return self.dropout(self.encoder(flat))

# Baseline model

In [18]:
class ConvBN(nn.Sequential):
    """Bias-free ``Conv2d`` followed by ``BatchNorm2d``.

    Padding is ``(k - 1) // 2`` per kernel dimension, which keeps the spatial
    size for odd kernels at stride 1.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        kernel_size: int or per-dimension sequence of ints.
        stride: convolution stride.
        groups: convolution groups.
    """
    def __init__(self, in_planes, out_planes, kernel_size, stride=1, groups=1):
        if isinstance(kernel_size, int):
            padding = (kernel_size - 1) // 2
        else:
            padding = [(k - 1) // 2 for k in kernel_size]

        conv = nn.Conv2d(
            in_channels=in_planes,
            out_channels=out_planes,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=groups,
            bias=False,
        )
        bn = nn.BatchNorm2d(out_planes)
        super().__init__(OrderedDict([('conv', conv), ('bn', bn)]))

In [19]:
class ResBlock(nn.Module):
    """Residual block on 2-channel maps: 2 -> 8 -> 16 -> 2 channels plus a skip connection."""
    def __init__(self):
        super().__init__()
        layers = OrderedDict()
        layers["conv_1"] = ConvBN(2, 8, kernel_size=3)
        layers["conv_2"] = ConvBN(8, 16, kernel_size=3)
        layers["conv_3"] = nn.Conv2d(16, 2, kernel_size=3, stride=1, padding=1)
        layers["bn"] = nn.BatchNorm2d(2)
        self.direct_path = nn.Sequential(layers)
        self.identity = nn.Identity()
        self.relu = nn.LeakyReLU(negative_slope=0.3, inplace=True)

    def forward(self, x):
        """Return ``LeakyReLU(direct_path(x) + x)``; the shape matches ``x``."""
        return self.relu(self.direct_path(x) + self.identity(x))

In [20]:
class task2Encoder(nn.Module):
    """Compress a 1-channel map into a vector of ``8192 // reduction`` values.

    A ``ConvBN(1, 2)`` layer produces two channels, whose flattened output
    must hold 8192 values (e.g. 2 x 64 x 64); a linear layer then reduces it.

    Args:
        reduction: divisor applied to 8192 to get the output width.
    """

    def __init__(self, reduction=16):
        super().__init__()
        self.total_size = 8192
        # The former ``n1 = int(math.log2(reduction))`` was never used and
        # has been dropped.
        self.encoder_convbn = ConvBN(1, 2, kernel_size=3)
        self.encoder_fc = nn.Linear(self.total_size, self.total_size // reduction)

    def forward(self, x):
        """Return ``(batch, 8192 // reduction)`` codes for a 4-D input ``x``."""
        # 4-way unpacking keeps the original requirement that x is 4-D.
        n, _, _, _ = x.detach().size()
        out = self.encoder_convbn(x.to(torch.float32))
        return self.encoder_fc(out.view(n, -1))
       

In [21]:
class task2Decoder(nn.Module):
    """Expand a code of ``8192 // reduction`` values into a 1 x 64 x 64 map in (0, 1).

    Pipeline: linear expansion to 8192 values, reshape to 2 x 64 x 64, two
    residual blocks, a 3x3 convolution, then a linear layer to 4096 values
    followed by a sigmoid.

    Args:
        reduction: divisor applied to 8192 to get the expected input width.
    """

    def __init__(self, reduction=16):
        super().__init__()
        self.total_size = 8192
        self.reduced_size = self.total_size // reduction
        self.decoder_fc = nn.Linear(self.total_size // reduction, self.total_size)

        self.decoder_RefineNet1 = ResBlock()
        self.decoder_RefineNet2 = ResBlock()
        self.decoder_conv = nn.Conv2d(2, 2, kernel_size=3, stride=1, padding=1)
        # Not used in forward(); kept so the set of parameters/buffers stays the same.
        self.decoder_bn = nn.BatchNorm2d(2)

        self.decoder_fc2 = nn.Linear(self.total_size, self.total_size // 2)
        self.sig2 = nn.Sigmoid()

    def forward(self, Hencoded):
        """Return a ``(batch, 1, 64, 64)`` tensor decoded from ``Hencoded``."""
        bs = Hencoded.size(0)
        code = Hencoded.view(bs, self.reduced_size)

        feature_map = self.decoder_fc(code).view(bs, -1, 64, 64)
        feature_map = self.decoder_RefineNet1(feature_map)
        feature_map = self.decoder_RefineNet2(feature_map)
        feature_map = self.decoder_conv(feature_map)

        flat = self.decoder_fc2(feature_map.view(bs, -1))
        return self.sig2(flat).view(bs, -1, 64, 64)

In [22]:
class task3Encoder(nn.Module):
    """Wrap a ``task2Encoder`` and optionally block gradient tracking through it.

    Args:
        reduction: forwarded to ``task2Encoder``.

    Attributes are documented inline below.
    """

    def __init__(self, reduction=16):
        super().__init__()
        self.en = task2Encoder(reduction)
        # When False, forward() runs the wrapped encoder without gradient tracking.
        self.allow_update = False

    def forward(self, x):
        """Return the wrapped encoder's output for ``x``."""
        # allow_update=True keeps the caller's grad mode; False forces it off.
        track_grad = self.allow_update and torch.is_grad_enabled()
        with torch.set_grad_enabled(track_grad):
            encoded_features = self.en(x)
        return encoded_features
       

In [23]:
class task3Decoder(nn.Module):
    """Decode the current code and mix it with the two previous outputs.

    The output is ``decoded * a + prev_1 * b + prev_2 * c + d`` with four
    learnable element-wise vectors of length 4096.

    Args:
        reduction: forwarded to ``task2Decoder``.
    """

    def __init__(self, reduction=16):
        super().__init__()
        self.total_size = 8192
        self.de = task2Decoder(reduction)

        # Element-wise autoregression coefficients (creation order unchanged).
        self.a = nn.Parameter(torch.randn(self.total_size // 2))
        self.b = nn.Parameter(torch.randn(self.total_size // 2))
        self.c = nn.Parameter(torch.randn(self.total_size // 2))
        self.d = nn.Parameter(torch.randn(self.total_size // 2))
        # When False, forward() runs the wrapped decoder without gradient tracking.
        self.allow_update = False

    def forward(self, Hencoded, input_autoregressive_features):
        """Return ``(output, autoregressive_features)``.

        Args:
            Hencoded: batch of codes accepted by the wrapped ``task2Decoder``.
            input_autoregressive_features: tensor indexed as ``[:, 0, :]`` for
                the previous output and ``[:, 1, :]`` for the one before it.

        Returns:
            ``output`` of shape ``(batch, 1, 64, 64)`` and the same values
            flattened to ``(batch, 4096)``.
        """
        bs = Hencoded.size(0)
        prev_1 = input_autoregressive_features[:, 0, :].view(bs, -1)
        prev_2 = input_autoregressive_features[:, 1, :].view(bs, -1)

        # allow_update=True keeps the caller's grad mode; False forces it off.
        with torch.set_grad_enabled(self.allow_update and torch.is_grad_enabled()):
            current = self.de(Hencoded)

        gain_now, gain_1, gain_2, offset = (
            coeff.expand(bs, -1) for coeff in (self.a, self.b, self.c, self.d)
        )
        mixed = current.view(bs, -1) * gain_now + prev_1 * gain_1 + prev_2 * gain_2 + offset

        return mixed.view(bs, 1, 64, 64), mixed
       

In [24]:
# Locations of the baseline (no side information) task-3 weights.
onoffdictb = {'GPS': False, 'CAMERAS': False, 'RADAR': False}  # baseline dictionary

weight_pathb = (
    f'models/CSINettask3/cr{reduction}/'
    f'gps{onoffdictb["GPS"]}_cam{onoffdictb["CAMERAS"]}_rad{onoffdictb["RADAR"]}/'
)

In [25]:
# Directory of the task-1 checkpoint matching the enabled modalities.
task1_weight_path = (
    'models/task1/'
    f'gps{onoffdict["GPS"]}_cam{onoffdict["CAMERAS"]}_rad{onoffdict["RADAR"]}/'
)

In [26]:
class task1decoder(nn.Module):
    """Predict a ``num_H x num_H`` map from GPS, radar and three camera inputs.

    Each modality is reduced to 16 features; the 80 concatenated features pass
    through two linear layers with ReLU.
    """
    def __init__(self):
        super().__init__()

        self.gp = gpsdata()
        self.rd = radardata()
        self.lc = cameradata()
        self.cc = cameradata()
        self.rc = cameradata()

        # Hidden width is (num_H/2)^2, but never below 32.
        quarter_map = int(num_H/2)*int(num_H/2)
        hidden_width = quarter_map if quarter_map > 32 else 32
        self.linear = nn.Linear(16*5, hidden_width)
        self.output_fc = nn.Linear(hidden_width, num_H*num_H)
        self.output_relu = nn.ReLU()

        for module in self.modules():
            if isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
            elif isinstance(module, (nn.Conv2d, nn.Linear)):
                nn.init.xavier_uniform_(module.weight)

    def forward(self,gps,radar,left_cam,center_cam,right_cam,onoffdict):
        """Return a ``(batch, 1, num_H, num_H)`` prediction.

        ``onoffdict`` has boolean keys ``'GPS'``, ``'RADAR'`` and ``'CAMERAS'``;
        a disabled modality contributes zeros instead of features.
        """
        bs = gps.size(0)

        def disabled():
            # Placeholder features for a switched-off modality.
            return torch.zeros(bs, 16).to(device)

        gps_out = self.gp(gps) if onoffdict['GPS'] else disabled()
        radar_out = self.rd(radar) if onoffdict['RADAR'] else disabled()

        if onoffdict['CAMERAS']:
            left_cam, center_cam, right_cam = process_imgs(gps, left_cam, center_cam, right_cam, crop_size = (150,150))
            lc_out = self.lc(left_cam)
            cc_out = self.cc(center_cam)
            rc_out = self.rc(right_cam)
        else:
            lc_out, cc_out, rc_out = disabled(), disabled(), disabled()

        combined = torch.cat((gps_out, radar_out, lc_out, cc_out, rc_out), dim=1)

        hidden = self.output_relu(self.linear(combined))
        output = self.output_relu(self.output_fc(hidden))
        return output.view(output.size(0), 1, num_H, num_H)

In [27]:
class Decoderwithmsi(nn.Module):
    """Fuse the baseline decoder output with side-information features.

    The task-1 modality encoders and the baseline task-3 decoder are loaded
    from disk; two new linear layers map the concatenation of both to a
    ``num_H x num_H`` output.

    Args:
        reduction: accepted for interface compatibility; not used here.
    """
    def __init__(self, reduction):
        super().__init__()
        # NOTE(security): torch.load unpickles arbitrary Python objects; only
        # load checkpoints from a trusted source.
        self.task1decoder = torch.load(task1_weight_path+"task1Decoder.pth")
        self.gp = self.task1decoder.gp
        self.rd = self.task1decoder.rd
        self.lc = self.task1decoder.lc
        self.cc = self.task1decoder.cc
        self.rc = self.task1decoder.rc
        self.bde = torch.load(weight_pathb+"task3Decoder.pth")
        # When False, the loaded sub-modules run without gradient tracking.
        self.allow_update = False

        self.linear = self.task1decoder.linear
        # Width of the fusion input: side-information features + flattened baseline output.
        fused_width = self.linear.out_features + num_H*num_H
        self.output_fc1 = nn.Linear(fused_width, 2*num_H*num_H)
        self.output_fc2 = nn.Linear(2*num_H*num_H, num_H*num_H)
        self.output_relu = nn.ReLU()

        # Initialise only the layers created here. Looping over
        # self.modules() would also overwrite the pretrained weights of the
        # modules loaded above.
        for fresh_layer in (self.output_fc1, self.output_fc2):
            nn.init.xavier_uniform_(fresh_layer.weight)

    def _side_features(self, bs, gps, radar, left_cam, center_cam, right_cam, onoffdict, dev):
        """Return the ``(bs, 80)`` concatenation of the five 16-wide modality features."""
        def disabled():
            # Placeholder features for a switched-off modality.
            return torch.zeros(bs, 16, device=dev)

        gps_out = self.gp(gps) if onoffdict['GPS'] else disabled()
        radar_out = self.rd(radar) if onoffdict['RADAR'] else disabled()

        if onoffdict['CAMERAS']:
            left_cam, center_cam, right_cam = process_imgs(gps, left_cam, center_cam, right_cam, crop_size = (150,150))
            lc_out = self.lc(left_cam)
            cc_out = self.cc(center_cam)
            rc_out = self.rc(right_cam)
        else:
            lc_out, cc_out, rc_out = disabled(), disabled(), disabled()

        return torch.cat((gps_out, radar_out, lc_out, cc_out, rc_out), dim=1)

    def forward(self, Hencoded,gps,radar,left_cam,center_cam,right_cam,onoffdict, input_autoregressive_features):
        """Return ``(output, autoregressive_features)``.

        ``output`` has shape ``(batch, 1, num_H, num_H)``;
        ``autoregressive_features`` is the same data flattened per sample.
        ``onoffdict`` has boolean keys ``'GPS'``, ``'RADAR'``, ``'CAMERAS'``.
        """
        bs = Hencoded.size(0)
        # allow_update=True keeps the caller's grad mode; False forces it off
        # for the loaded (pretrained) parts only.
        with torch.set_grad_enabled(self.allow_update and torch.is_grad_enabled()):
            Hdecoded, _ = self.bde(Hencoded, input_autoregressive_features)
            combined1 = self._side_features(
                bs, gps, radar, left_cam, center_cam, right_cam, onoffdict, Hencoded.device)
            output = self.linear(combined1)

        output = self.output_relu(output)
        combined2 = torch.cat((output, Hdecoded.view(bs,-1)), dim=1)
        output = self.output_fc1(combined2)
        output = self.output_relu(output)
        output = self.output_fc2(output)

        autoregressive_features = output
        output = output.view(bs, 1, num_H, num_H)

        return output, autoregressive_features

In [28]:
#complete task 3 model including encoder, decoder and channel
class task3model(nn.Module):
    """Encoder plus side-information decoder, applied one time step at a time.

    The decoder output of each step is cached in ``self.ar`` and fed back, as
    detached tensors, at the next two steps.

    Args:
        reduction: divisor applied to 8192; also forwarded to ``Decoderwithmsi``.
    """
    def __init__(self, reduction=16):
        super().__init__()
        self.total_size = 8192
        self.reduced_size = self.total_size//reduction
        # NOTE(security): torch.load unpickles arbitrary Python objects; only
        # load checkpoints from a trusted source.
        self.en = torch.load(weight_pathb+"task3Encoder.pth")
        self.de = Decoderwithmsi(reduction)
        self.ar = [None] * 5  # List to store the AR variables

    def forward(self, Hin, gps, radar, left_cam, center_cam, right_cam, time_index, device, is_training, onoffdict):
        """Decode step ``time_index`` and cache its autoregressive features.

        Args:
            Hin: encoder input batch.
            gps, radar, left_cam, center_cam, right_cam: side information,
                forwarded to the decoder.
            time_index: position in the sequence; steps must be evaluated in
                order starting from 0.
            device: device on which zero history tensors are created.
            is_training: accepted for interface compatibility; not used.
            onoffdict: modality flags forwarded to the decoder.

        Returns:
            The decoder's first output for this step.

        Raises:
            RuntimeError: if a required earlier step has not been evaluated.
        """
        batch_size = Hin.shape[0]
        ar_width = self.total_size // 2
        Hencoded = self.en(Hin)

        # Grow the cache on demand so sequences longer than its initial
        # length do not fail with IndexError.
        if time_index >= len(self.ar):
            self.ar.extend([None] * (time_index + 1 - len(self.ar)))

        def history(lag):
            # Features from ``lag`` steps back, or zeros before the sequence start.
            step = time_index - lag
            if step < 0:
                return torch.zeros((batch_size, 1, ar_width), dtype=torch.float, device=device)
            if self.ar[step] is None:
                raise RuntimeError(
                    f"time step {step} must be evaluated before time step {time_index}")
            return self.ar[step].view(batch_size, 1, ar_width).detach()

        iarf = torch.cat([history(1), history(2)], dim=1)
        Hdecoded, self.ar[time_index] = self.de(
            Hencoded, gps, radar, left_cam, center_cam, right_cam, onoffdict, iarf)

        return Hdecoded


In [29]:
# Build the model; the constructor loads the pretrained encoder/decoder files.
# NOTE(review): the inference cell below rebinds `model` to the object loaded
# from task3.pth, so this instance is not the one evaluated.
model = task3model(reduction)

In [30]:
# Training

In [31]:
# Make sure the "models" folder exists. exist_ok=True avoids the
# check-then-create race of a separate os.path.exists() test.
os.makedirs("models", exist_ok=True)


In [32]:
# Loss: mean squared error between prediction and target.
criterion = nn.MSELoss()
criterion = criterion.to(device)

# Inference

In [33]:
def run_test(model, test_loader, device, criterion):
    """Evaluate ``model`` on ``test_loader`` and return the mean per-step loss.

    Each item yielded by ``test_loader`` is iterated as a sequence of
    ``(X, y)`` time steps, where ``X[0]`` is the model input and
    ``X[1]..X[5]`` are the side-information tensors passed to the model in
    that order. The model's autoregressive cache is reset for every batch.
    Reads the module-level ``onoffdict`` for the modality flags.

    Args:
        model: module exposing an ``ar`` list and the task-3 forward signature.
        test_loader: iterable of batched sequences.
        device: device the tensors are moved to.
        criterion: loss returning a scalar tensor.

    Returns:
        float: unweighted mean of the loss over all evaluated (batch, step) pairs.

    Raises:
        ValueError: if ``test_loader`` yields no time steps.
    """
    model.eval()
    total_loss = 0.0
    # Count evaluated steps instead of assuming 5 steps per batch.
    num_steps = 0
    with torch.no_grad():
        for t_x in test_loader:
            model.ar = [None] * 5
            for time_index, (X, y) in enumerate(t_x):
                y_test_reshaped = CSI_abs_reshape(y.to(device))
                Xin = CSI_abs_reshape(X[0].to(device))
                # Get the input and output for the given time index.
                # is_training=False: this is inference.
                y_pred = model(Xin, X[1].to(device), X[2].to(device), X[3].to(device),
                               X[4].to(device), X[5].to(device), time_index, device,
                               is_training=False, onoffdict=onoffdict)
                total_loss += criterion(y_pred, y_test_reshaped).item()
                num_steps += 1
    if num_steps == 0:
        raise ValueError("test_loader produced no samples to evaluate")
    return total_loss / num_steps

In [34]:
def calculate_confidence_interval(data, confidence=0.95):
    """Return the sample mean and the half-width of its Student-t confidence interval.

    Args:
        data: sequence of numeric observations.
        confidence: two-sided confidence level.

    Returns:
        ``(mean, h)`` such that the interval is ``mean - h`` to ``mean + h``.
    """
    degrees_of_freedom = len(data) - 1
    t_critical = stats.t.ppf((1 + confidence) / 2., degrees_of_freedom)
    half_width = stats.sem(data) * t_critical
    return np.mean(data), half_width

In [35]:
# Validation split and its loader (batches are reshuffled on every pass).
test_dataset = DeepVerseChallengeLoaderTaskThree(csv_path=r'./dataset_validation.csv')
test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
)

In [36]:
# Reference loss: variance of the rescaled target magnitudes over the whole
# validation set, i.e. the MSE of always predicting their mean.
# Targets are collected in a list and concatenated once; calling torch.cat
# inside the loop re-copies everything gathered so far on every step.
target_chunks = []
for t_x in test_loader:
    for X, y in t_x:
        target_chunks.append(CSI_abs_reshape(y))
h_list = torch.cat(target_chunks) if target_chunks else torch.tensor([])
h_abs = torch.abs(h_list)
target_loss = torch.mean((h_abs - torch.mean(h_abs)) ** 2)

In [37]:
# Number of repeated evaluation passes used for the confidence interval.
num_runs = 10

In [38]:

# Load the checkpoint once: it does not change between runs, and run_test()
# resets the model's autoregressive cache itself.
# map_location lets a checkpoint saved on another device load here.
# NOTE(security): torch.load unpickles arbitrary Python objects; only load
# checkpoints from a trusted source.
model = torch.load(weight_path + "task3.pth", map_location=device).to(device)
baseline_loss = target_loss.item()

avg_mse_list = []
improvement_list = []
for _ in range(num_runs):
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
    avg_mse = run_test(model, test_loader, device, criterion)
    avg_mse_list.append(avg_mse)
    # Relative reduction of the MSE with respect to the reference loss, in percent.
    improvement = (baseline_loss - avg_mse) / baseline_loss * 100
    improvement_list.append(improvement)
mean_mse, margin_of_error = calculate_confidence_interval(avg_mse_list)
improvement_mean, improvement_margin_of_error = calculate_confidence_interval(improvement_list)
print(f'Percentage Improvement Mean Achieved: {improvement_mean:.4f}%')
print(f'Percentage Improvement Confidence Interval Achieved: {improvement_margin_of_error:.4f}%')
print(f"Mean MSE: {mean_mse:.4f}")
print(f"95% Confidence Interval: ({mean_mse - margin_of_error:.4f}, {mean_mse + margin_of_error:.4f})")
print(f"Margin of Error: {margin_of_error:.4f}")

Percentage Improvement Mean Achieved: 69.4910%
Percentage Improvement Confidence Interval Achieved: 0.8062%
Mean MSE: 0.3444
95% Confidence Interval: (0.3353, 0.3535)
Margin of Error: 0.0091


# Change Dictionary

In [39]:
# Switch to the GPS + cameras configuration (radar disabled).
onoffdict = {'GPS': True, 'CAMERAS': True, 'RADAR': False}
weight_path = (
    f'models/CSINettask3/cr{reduction}/'
    f'gps{onoffdict["GPS"]}_cam{onoffdict["CAMERAS"]}_rad{onoffdict["RADAR"]}/'
)

In [40]:
# Create the checkpoint directory if needed. exist_ok=True avoids the
# check-then-create race of a separate os.path.exists() test.
os.makedirs(weight_path, exist_ok=True)

In [41]:
# Directory of the task-1 checkpoint matching the enabled modalities.
task1_weight_path = (
    'models/task1/'
    f'gps{onoffdict["GPS"]}_cam{onoffdict["CAMERAS"]}_rad{onoffdict["RADAR"]}/'
)

In [42]:
class Decoderwithmsi(nn.Module):
    """Fuse the baseline decoder output with side-information features.

    The task-1 modality encoders and the baseline task-3 decoder are loaded
    from disk; two new linear layers map the concatenation of both to a
    ``num_H x num_H`` output.

    Args:
        reduction: accepted for interface compatibility; not used here.
    """
    def __init__(self, reduction):
        super().__init__()
        # NOTE(security): torch.load unpickles arbitrary Python objects; only
        # load checkpoints from a trusted source.
        self.task1decoder = torch.load(task1_weight_path+"task1Decoder.pth")
        self.gp = self.task1decoder.gp
        self.rd = self.task1decoder.rd
        self.lc = self.task1decoder.lc
        self.cc = self.task1decoder.cc
        self.rc = self.task1decoder.rc
        self.bde = torch.load(weight_pathb+"task3Decoder.pth")
        # When False, the loaded sub-modules run without gradient tracking.
        self.allow_update = False

        self.linear = self.task1decoder.linear
        # Width of the fusion input: side-information features + flattened baseline output.
        fused_width = self.linear.out_features + num_H*num_H
        self.output_fc1 = nn.Linear(fused_width, 2*num_H*num_H)
        self.output_fc2 = nn.Linear(2*num_H*num_H, num_H*num_H)
        self.output_relu = nn.ReLU()

        # Initialise only the layers created here. Looping over
        # self.modules() would also overwrite the pretrained weights of the
        # modules loaded above.
        for fresh_layer in (self.output_fc1, self.output_fc2):
            nn.init.xavier_uniform_(fresh_layer.weight)

    def _side_features(self, bs, gps, radar, left_cam, center_cam, right_cam, onoffdict, dev):
        """Return the ``(bs, 80)`` concatenation of the five 16-wide modality features."""
        def disabled():
            # Placeholder features for a switched-off modality.
            return torch.zeros(bs, 16, device=dev)

        gps_out = self.gp(gps) if onoffdict['GPS'] else disabled()
        radar_out = self.rd(radar) if onoffdict['RADAR'] else disabled()

        if onoffdict['CAMERAS']:
            left_cam, center_cam, right_cam = process_imgs(gps, left_cam, center_cam, right_cam, crop_size = (150,150))
            lc_out = self.lc(left_cam)
            cc_out = self.cc(center_cam)
            rc_out = self.rc(right_cam)
        else:
            lc_out, cc_out, rc_out = disabled(), disabled(), disabled()

        return torch.cat((gps_out, radar_out, lc_out, cc_out, rc_out), dim=1)

    def forward(self, Hencoded,gps,radar,left_cam,center_cam,right_cam,onoffdict, input_autoregressive_features):
        """Return ``(output, autoregressive_features)``.

        ``output`` has shape ``(batch, 1, num_H, num_H)``;
        ``autoregressive_features`` is the same data flattened per sample.
        ``onoffdict`` has boolean keys ``'GPS'``, ``'RADAR'``, ``'CAMERAS'``.
        """
        bs = Hencoded.size(0)
        # allow_update=True keeps the caller's grad mode; False forces it off
        # for the loaded (pretrained) parts only.
        with torch.set_grad_enabled(self.allow_update and torch.is_grad_enabled()):
            Hdecoded, _ = self.bde(Hencoded, input_autoregressive_features)
            combined1 = self._side_features(
                bs, gps, radar, left_cam, center_cam, right_cam, onoffdict, Hencoded.device)
            output = self.linear(combined1)

        output = self.output_relu(output)
        combined2 = torch.cat((output, Hdecoded.view(bs,-1)), dim=1)
        output = self.output_fc1(combined2)
        output = self.output_relu(output)
        output = self.output_fc2(output)

        autoregressive_features = output
        output = output.view(bs, 1, num_H, num_H)

        return output, autoregressive_features

In [43]:
#complete task 3 model including encoder, decoder and channel
class task3model(nn.Module):
    """Encoder plus side-information decoder, applied one time step at a time.

    The decoder output of each step is cached in ``self.ar`` and fed back, as
    detached tensors, at the next two steps.

    Args:
        reduction: divisor applied to 8192; also forwarded to ``Decoderwithmsi``.
    """
    def __init__(self, reduction=16):
        super().__init__()
        self.total_size = 8192
        self.reduced_size = self.total_size//reduction
        # NOTE(security): torch.load unpickles arbitrary Python objects; only
        # load checkpoints from a trusted source.
        self.en = torch.load(weight_pathb+"task3Encoder.pth")
        self.de = Decoderwithmsi(reduction)
        self.ar = [None] * 5  # List to store the AR variables

    def forward(self, Hin, gps, radar, left_cam, center_cam, right_cam, time_index, device, is_training, onoffdict):
        """Decode step ``time_index`` and cache its autoregressive features.

        Args:
            Hin: encoder input batch.
            gps, radar, left_cam, center_cam, right_cam: side information,
                forwarded to the decoder.
            time_index: position in the sequence; steps must be evaluated in
                order starting from 0.
            device: device on which zero history tensors are created.
            is_training: accepted for interface compatibility; not used.
            onoffdict: modality flags forwarded to the decoder.

        Returns:
            The decoder's first output for this step.

        Raises:
            RuntimeError: if a required earlier step has not been evaluated.
        """
        batch_size = Hin.shape[0]
        ar_width = self.total_size // 2
        Hencoded = self.en(Hin)

        # Grow the cache on demand so sequences longer than its initial
        # length do not fail with IndexError.
        if time_index >= len(self.ar):
            self.ar.extend([None] * (time_index + 1 - len(self.ar)))

        def history(lag):
            # Features from ``lag`` steps back, or zeros before the sequence start.
            step = time_index - lag
            if step < 0:
                return torch.zeros((batch_size, 1, ar_width), dtype=torch.float, device=device)
            if self.ar[step] is None:
                raise RuntimeError(
                    f"time step {step} must be evaluated before time step {time_index}")
            return self.ar[step].view(batch_size, 1, ar_width).detach()

        iarf = torch.cat([history(1), history(2)], dim=1)
        Hdecoded, self.ar[time_index] = self.de(
            Hencoded, gps, radar, left_cam, center_cam, right_cam, onoffdict, iarf)

        return Hdecoded


In [44]:
# Build the model; the constructor loads the pretrained encoder/decoder files.
# NOTE(review): the inference cell below rebinds `model` to the object loaded
# from task3.pth, so this instance is not the one evaluated.
model = task3model(reduction)

# Inference

In [45]:

# Load the checkpoint once: it does not change between runs, and run_test()
# resets the model's autoregressive cache itself.
# map_location lets a checkpoint saved on another device load here.
# NOTE(security): torch.load unpickles arbitrary Python objects; only load
# checkpoints from a trusted source.
model = torch.load(weight_path + "task3.pth", map_location=device).to(device)
baseline_loss = target_loss.item()

avg_mse_list = []
improvement_list = []
for _ in range(num_runs):
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
    avg_mse = run_test(model, test_loader, device, criterion)
    avg_mse_list.append(avg_mse)
    # Relative reduction of the MSE with respect to the reference loss, in percent.
    improvement = (baseline_loss - avg_mse) / baseline_loss * 100
    improvement_list.append(improvement)
mean_mse, margin_of_error = calculate_confidence_interval(avg_mse_list)
improvement_mean, improvement_margin_of_error = calculate_confidence_interval(improvement_list)
print(f'Percentage Improvement Mean Achieved: {improvement_mean:.4f}%')
print(f'Percentage Improvement Confidence Interval Achieved: {improvement_margin_of_error:.4f}%')
print(f"Mean MSE: {mean_mse:.4f}")
print(f"95% Confidence Interval: ({mean_mse - margin_of_error:.4f}, {mean_mse + margin_of_error:.4f})")
print(f"Margin of Error: {margin_of_error:.4f}")

Percentage Improvement Mean Achieved: 68.4398%
Percentage Improvement Confidence Interval Achieved: 0.9086%
Mean MSE: 0.3562
95% Confidence Interval: (0.3460, 0.3665)
Margin of Error: 0.0103


# Change Dictionary

In [46]:
# Switch to the GPS + radar configuration (cameras disabled).
onoffdict = {'GPS': True, 'CAMERAS': False, 'RADAR': True}
weight_path = (
    f'models/CSINettask3/cr{reduction}/'
    f'gps{onoffdict["GPS"]}_cam{onoffdict["CAMERAS"]}_rad{onoffdict["RADAR"]}/'
)

In [47]:
# Create the checkpoint directory if needed. exist_ok=True avoids the
# check-then-create race of a separate os.path.exists() test.
os.makedirs(weight_path, exist_ok=True)

In [48]:
# Directory of the task-1 checkpoint matching the enabled modalities.
task1_weight_path = (
    'models/task1/'
    f'gps{onoffdict["GPS"]}_cam{onoffdict["CAMERAS"]}_rad{onoffdict["RADAR"]}/'
)

In [49]:
class Decoderwithmsi(nn.Module):
    """Decoder that fuses the baseline CSI reconstruction with multi-sensor features.

    Pretrained parts are loaded from disk in ``__init__``: the task-1 decoder
    provides the sensor branches (``gp``, ``rd``, ``lc``, ``cc``, ``rc``) and
    the fusion layer ``linear``; ``bde`` is the task-3 baseline decoder. The
    only layers created here are ``output_fc1`` and ``output_fc2``.

    Depends on the notebook globals ``task1_weight_path``, ``weight_pathb``,
    ``num_H`` and ``process_imgs``.
    """

    def __init__(self, reduction):
        """Load the pretrained sub-networks and build the output head.

        Args:
            reduction: Not used by this constructor; kept so existing callers
                (``Decoderwithmsi(reduction)``) keep working.
        """
        super().__init__()
        # NOTE(security): torch.load unpickles arbitrary objects -- only point
        # these paths at trusted checkpoints.
        self.task1decoder = torch.load(task1_weight_path+"task1Decoder.pth")
        self.gp = self.task1decoder.gp
        self.rd = self.task1decoder.rd
        self.lc = self.task1decoder.lc
        self.cc = self.task1decoder.cc
        self.rc = self.task1decoder.rc
        self.bde = torch.load(weight_pathb+"task3Decoder.pth")
        # While False, forward() runs every pretrained part under no_grad.
        self.allow_update = False

        self.linear = self.task1decoder.linear
        # Width reserved for the sensor features in output_fc1's input. The
        # former if/else branches differed only in this number.
        # Presumably equals the output width of task1decoder.linear -- confirm.
        msi_width = max(int(num_H/2)*int(num_H/2), 32)
        self.output_fc1 = nn.Linear(msi_width+num_H*num_H, 2*num_H*num_H)
        self.output_fc2 = nn.Linear(2*num_H*num_H, num_H*num_H)
        self.output_relu = nn.ReLU()

        # Initialise ONLY the freshly created layers. Walking self.modules()
        # here would also re-initialise the pretrained task1decoder / bde
        # weights that were loaded a few lines above.
        for fresh_layer in (self.output_fc1, self.output_fc2):
            nn.init.xavier_uniform_(fresh_layer.weight)

    def forward(self, Hencoded,gps,radar,left_cam,center_cam,right_cam,onoffdict, input_autoregressive_features):
        """Decode ``Hencoded`` and refine the result with the enabled sensors.

        Args:
            Hencoded: Encoder output; its first dimension is the batch size.
            gps, radar, left_cam, center_cam, right_cam: Sensor inputs. Each is
                only used when the matching ``onoffdict`` entry is truthy
                (``gps`` is additionally passed to ``process_imgs`` when
                cameras are enabled).
            onoffdict: Mapping with the keys ``'GPS'``, ``'RADAR'`` and
                ``'CAMERAS'``.
            input_autoregressive_features: Forwarded unchanged to ``self.bde``.

        Returns:
            Tuple ``(output, autoregressive_features)``: ``output`` has shape
            ``(bs, 1, num_H, num_H)`` and ``autoregressive_features`` is the
            same data before that reshape.
        """
        bs = Hencoded.size(0)
        # Placeholders follow the input's device instead of a notebook global.
        target_device = Hencoded.device

        def disabled_branch():
            # A switched-off sensor contributes an all-zero 16-wide feature.
            return torch.zeros(bs, 16, device=target_device)

        # Gradients reach the pretrained parts only when allow_update is set
        # AND the caller has not disabled autograd itself (e.g. during eval).
        track_grad = bool(self.allow_update) and torch.is_grad_enabled()
        with torch.set_grad_enabled(track_grad):
            Hdecoded, _ = self.bde(Hencoded, input_autoregressive_features)

            gps_out = self.gp(gps) if onoffdict['GPS'] else disabled_branch()
            radar_out = self.rd(radar) if onoffdict['RADAR'] else disabled_branch()

            if onoffdict['CAMERAS']:
                left_cam, center_cam, right_cam = process_imgs(gps, left_cam, center_cam, right_cam, crop_size = (150,150))
                lc_out = self.lc(left_cam)
                cc_out = self.cc(center_cam)
                rc_out = self.rc(right_cam)
            else:
                lc_out = disabled_branch()
                cc_out = disabled_branch()
                rc_out = disabled_branch()

            combined1 = torch.cat((gps_out, radar_out, lc_out, cc_out, rc_out), dim=1)
            output = self.linear(combined1)

        # Output head: always runs in the caller's autograd mode.
        output = self.output_relu(output)
        combined2 = torch.cat((output, Hdecoded.view(bs,-1)), dim=1)
        output = self.output_fc1(combined2)
        output = self.output_relu(output)
        output = self.output_fc2(output)

        autoregressive_features = output
        output = output.view(bs, 1, num_H, num_H)

        return output, autoregressive_features

In [50]:
# Full task-3 pipeline: pretrained encoder followed by the sensor-aided decoder.
class task3model(nn.Module):
    """End-to-end task-3 model with a two-step autoregressive decoder memory.

    The encoder output is handed to the decoder unchanged (nothing is applied
    in between). After every call the decoder's feature tensor is stored in
    ``self.ar`` so that the next time steps can feed it back in.
    """

    def __init__(self, reduction=16):
        """Load the pretrained encoder and build the decoder.

        Reads the notebook global ``weight_pathb`` for the encoder checkpoint.
        """
        super().__init__()
        self.total_size = 8192
        self.reduced_size = self.total_size // reduction
        self.en = torch.load(weight_pathb + "task3Encoder.pth")
        self.de = Decoderwithmsi(reduction)
        # One slot per time step for the decoder's autoregressive features.
        self.ar = [None] * 5

    def forward(self, Hin, gps, radar, left_cam, center_cam, right_cam, time_index, device, is_training, onoffdict): 
        """Encode ``Hin``, decode it with sensor data and the stored history.

        ``time_index`` selects how much history exists: none at step 0 (all
        zeros), one step at step 1 (second entry zero-filled), otherwise the
        two preceding steps, most recent first. ``is_training`` is accepted
        but not used. Returns the decoder's reconstruction.
        """
        n_samples = Hin.shape[0]
        ar_width = self.total_size // 2
        code = self.en(Hin)

        def empty(slots):
            # Zero history for time steps that have not happened yet.
            return torch.zeros((n_samples, slots, ar_width), dtype=torch.float).to(device)

        def recall(step):
            # Stored features of an earlier step, cut off from the autograd graph.
            return self.ar[step].view(n_samples, 1, ar_width).detach()

        if time_index == 0:
            slot = 0
            iarf = empty(2)
        elif time_index == 1:
            slot = 1
            iarf = torch.cat([recall(0), empty(1)], dim=1)
        else:
            slot = time_index
            iarf = torch.cat([recall(time_index - 1), recall(time_index - 2)], dim=1)

        Hdecoded, self.ar[slot] = self.de(code, gps, radar, left_cam, center_cam, right_cam, onoffdict, iarf)
        return Hdecoded


In [51]:
# Build a task3model from the pretrained encoder/decoder checkpoints. The
# inference cell that follows rebinds `model` to the checkpoint loaded from
# weight_path + "task3.pth", so this freshly built instance is not the one
# that gets evaluated.
model=task3model(reduction)

# Inference

In [52]:

# Evaluate the trained checkpoint for the current sensor configuration:
# repeat the test pass `num_runs` times, then report the mean MSE and the mean
# percentage improvement over `target_loss`, each with its margin of error.
avg_mse_list = []
improvement_list = []
# The checkpoint, the loader and the reference loss are the same for every
# run, so they are prepared once instead of re-reading the model from disk
# (and moving it to the device) on each iteration. With shuffle=True the
# loader still draws a new sample order every time it is iterated.
model = torch.load(weight_path + "task3.pth").to(device)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
target_mse = target_loss.item()
for _ in range(num_runs):
    avg_mse = run_test(model, test_loader, device, criterion)
    avg_mse_list.append(avg_mse)
    # Relative gain of this run over the reference loss, in percent.
    improvement = (target_mse - avg_mse) / target_mse * 100
    improvement_list.append(improvement)
mean_mse, margin_of_error = calculate_confidence_interval(avg_mse_list)
improvement_mean, improvement_margin_of_error = calculate_confidence_interval(improvement_list)
print(f'Percentage Improvement Mean Achieved: {improvement_mean:.4f}%')
print(f'Percentage Improvement Confidence Interval Achieved: {improvement_margin_of_error:.4f}%')
print(f"Mean MSE: {mean_mse:.4f}")
print(f"95% Confidence Interval: ({mean_mse - margin_of_error:.4f}, {mean_mse + margin_of_error:.4f})")
print(f"Margin of Error: {margin_of_error:.4f}")

Percentage Improvement Mean Achieved: 69.3221%
Percentage Improvement Confidence Interval Achieved: 0.7057%
Mean MSE: 0.3463
95% Confidence Interval: (0.3383, 0.3542)
Margin of Error: 0.0080


# Change Dictionary

In [53]:
# Sensor selection for this run: GPS only; cameras and radar disabled.
onoffdict = {
    'GPS': True,
    'CAMERAS': False,
    'RADAR': False,
}
# Checkpoint directory for this compression ratio and sensor combination.
weight_path = f'models/CSINettask3/cr{reduction}/gps{onoffdict["GPS"]}_cam{onoffdict["CAMERAS"]}_rad{onoffdict["RADAR"]}/'

In [54]:
# Make sure the checkpoint directory for this configuration exists.
# exist_ok=True replaces the exists()/makedirs() pair, which could raise
# FileExistsError if the directory appeared between the two calls.
os.makedirs(weight_path, exist_ok=True)

In [55]:
# Directory of the pretrained task-1 decoder for the current sensor
# combination; Decoderwithmsi.__init__ loads "task1Decoder.pth" from it.
task1_weight_path=f'models/task1/gps{onoffdict["GPS"]}_cam{onoffdict["CAMERAS"]}_rad{onoffdict["RADAR"]}/'

In [56]:
class Decoderwithmsi(nn.Module):
    """Decoder that fuses the baseline CSI reconstruction with multi-sensor features.

    Pretrained parts are loaded from disk in ``__init__``: the task-1 decoder
    provides the sensor branches (``gp``, ``rd``, ``lc``, ``cc``, ``rc``) and
    the fusion layer ``linear``; ``bde`` is the task-3 baseline decoder. The
    only layers created here are ``output_fc1`` and ``output_fc2``.

    Depends on the notebook globals ``task1_weight_path``, ``weight_pathb``,
    ``num_H`` and ``process_imgs``.
    """

    def __init__(self, reduction):
        """Load the pretrained sub-networks and build the output head.

        Args:
            reduction: Not used by this constructor; kept so existing callers
                (``Decoderwithmsi(reduction)``) keep working.
        """
        super().__init__()
        # NOTE(security): torch.load unpickles arbitrary objects -- only point
        # these paths at trusted checkpoints.
        self.task1decoder = torch.load(task1_weight_path+"task1Decoder.pth")
        self.gp = self.task1decoder.gp
        self.rd = self.task1decoder.rd
        self.lc = self.task1decoder.lc
        self.cc = self.task1decoder.cc
        self.rc = self.task1decoder.rc
        self.bde = torch.load(weight_pathb+"task3Decoder.pth")
        # While False, forward() runs every pretrained part under no_grad.
        self.allow_update = False

        self.linear = self.task1decoder.linear
        # Width reserved for the sensor features in output_fc1's input. The
        # former if/else branches differed only in this number.
        # Presumably equals the output width of task1decoder.linear -- confirm.
        msi_width = max(int(num_H/2)*int(num_H/2), 32)
        self.output_fc1 = nn.Linear(msi_width+num_H*num_H, 2*num_H*num_H)
        self.output_fc2 = nn.Linear(2*num_H*num_H, num_H*num_H)
        self.output_relu = nn.ReLU()

        # Initialise ONLY the freshly created layers. Walking self.modules()
        # here would also re-initialise the pretrained task1decoder / bde
        # weights that were loaded a few lines above.
        for fresh_layer in (self.output_fc1, self.output_fc2):
            nn.init.xavier_uniform_(fresh_layer.weight)

    def forward(self, Hencoded,gps,radar,left_cam,center_cam,right_cam,onoffdict, input_autoregressive_features):
        """Decode ``Hencoded`` and refine the result with the enabled sensors.

        Args:
            Hencoded: Encoder output; its first dimension is the batch size.
            gps, radar, left_cam, center_cam, right_cam: Sensor inputs. Each is
                only used when the matching ``onoffdict`` entry is truthy
                (``gps`` is additionally passed to ``process_imgs`` when
                cameras are enabled).
            onoffdict: Mapping with the keys ``'GPS'``, ``'RADAR'`` and
                ``'CAMERAS'``.
            input_autoregressive_features: Forwarded unchanged to ``self.bde``.

        Returns:
            Tuple ``(output, autoregressive_features)``: ``output`` has shape
            ``(bs, 1, num_H, num_H)`` and ``autoregressive_features`` is the
            same data before that reshape.
        """
        bs = Hencoded.size(0)
        # Placeholders follow the input's device instead of a notebook global.
        target_device = Hencoded.device

        def disabled_branch():
            # A switched-off sensor contributes an all-zero 16-wide feature.
            return torch.zeros(bs, 16, device=target_device)

        # Gradients reach the pretrained parts only when allow_update is set
        # AND the caller has not disabled autograd itself (e.g. during eval).
        track_grad = bool(self.allow_update) and torch.is_grad_enabled()
        with torch.set_grad_enabled(track_grad):
            Hdecoded, _ = self.bde(Hencoded, input_autoregressive_features)

            gps_out = self.gp(gps) if onoffdict['GPS'] else disabled_branch()
            radar_out = self.rd(radar) if onoffdict['RADAR'] else disabled_branch()

            if onoffdict['CAMERAS']:
                left_cam, center_cam, right_cam = process_imgs(gps, left_cam, center_cam, right_cam, crop_size = (150,150))
                lc_out = self.lc(left_cam)
                cc_out = self.cc(center_cam)
                rc_out = self.rc(right_cam)
            else:
                lc_out = disabled_branch()
                cc_out = disabled_branch()
                rc_out = disabled_branch()

            combined1 = torch.cat((gps_out, radar_out, lc_out, cc_out, rc_out), dim=1)
            output = self.linear(combined1)

        # Output head: always runs in the caller's autograd mode.
        output = self.output_relu(output)
        combined2 = torch.cat((output, Hdecoded.view(bs,-1)), dim=1)
        output = self.output_fc1(combined2)
        output = self.output_relu(output)
        output = self.output_fc2(output)

        autoregressive_features = output
        output = output.view(bs, 1, num_H, num_H)

        return output, autoregressive_features

In [57]:
# Full task-3 pipeline: pretrained encoder followed by the sensor-aided decoder.
class task3model(nn.Module):
    """End-to-end task-3 model with a two-step autoregressive decoder memory.

    The encoder output is handed to the decoder unchanged (nothing is applied
    in between). After every call the decoder's feature tensor is stored in
    ``self.ar`` so that the next time steps can feed it back in.
    """

    def __init__(self, reduction=16):
        """Load the pretrained encoder and build the decoder.

        Reads the notebook global ``weight_pathb`` for the encoder checkpoint.
        """
        super().__init__()
        self.total_size = 8192
        self.reduced_size = self.total_size // reduction
        self.en = torch.load(weight_pathb + "task3Encoder.pth")
        self.de = Decoderwithmsi(reduction)
        # One slot per time step for the decoder's autoregressive features.
        self.ar = [None] * 5

    def forward(self, Hin, gps, radar, left_cam, center_cam, right_cam, time_index, device, is_training, onoffdict): 
        """Encode ``Hin``, decode it with sensor data and the stored history.

        ``time_index`` selects how much history exists: none at step 0 (all
        zeros), one step at step 1 (second entry zero-filled), otherwise the
        two preceding steps, most recent first. ``is_training`` is accepted
        but not used. Returns the decoder's reconstruction.
        """
        n_samples = Hin.shape[0]
        ar_width = self.total_size // 2
        code = self.en(Hin)

        def empty(slots):
            # Zero history for time steps that have not happened yet.
            return torch.zeros((n_samples, slots, ar_width), dtype=torch.float).to(device)

        def recall(step):
            # Stored features of an earlier step, cut off from the autograd graph.
            return self.ar[step].view(n_samples, 1, ar_width).detach()

        if time_index == 0:
            slot = 0
            iarf = empty(2)
        elif time_index == 1:
            slot = 1
            iarf = torch.cat([recall(0), empty(1)], dim=1)
        else:
            slot = time_index
            iarf = torch.cat([recall(time_index - 1), recall(time_index - 2)], dim=1)

        Hdecoded, self.ar[slot] = self.de(code, gps, radar, left_cam, center_cam, right_cam, onoffdict, iarf)
        return Hdecoded


In [58]:
# Build a task3model from the pretrained encoder/decoder checkpoints. The
# inference cell that follows rebinds `model` to the checkpoint loaded from
# weight_path + "task3.pth", so this freshly built instance is not the one
# that gets evaluated.
model=task3model(reduction)

# Inference

In [59]:

# Evaluate the trained checkpoint for the current sensor configuration:
# repeat the test pass `num_runs` times, then report the mean MSE and the mean
# percentage improvement over `target_loss`, each with its margin of error.
avg_mse_list = []
improvement_list = []
# The checkpoint, the loader and the reference loss are the same for every
# run, so they are prepared once instead of re-reading the model from disk
# (and moving it to the device) on each iteration. With shuffle=True the
# loader still draws a new sample order every time it is iterated.
model = torch.load(weight_path + "task3.pth").to(device)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
target_mse = target_loss.item()
for _ in range(num_runs):
    avg_mse = run_test(model, test_loader, device, criterion)
    avg_mse_list.append(avg_mse)
    # Relative gain of this run over the reference loss, in percent.
    improvement = (target_mse - avg_mse) / target_mse * 100
    improvement_list.append(improvement)
mean_mse, margin_of_error = calculate_confidence_interval(avg_mse_list)
improvement_mean, improvement_margin_of_error = calculate_confidence_interval(improvement_list)
print(f'Percentage Improvement Mean Achieved: {improvement_mean:.4f}%')
print(f'Percentage Improvement Confidence Interval Achieved: {improvement_margin_of_error:.4f}%')
print(f"Mean MSE: {mean_mse:.4f}")
print(f"95% Confidence Interval: ({mean_mse - margin_of_error:.4f}, {mean_mse + margin_of_error:.4f})")
print(f"Margin of Error: {margin_of_error:.4f}")

Percentage Improvement Mean Achieved: 68.7285%
Percentage Improvement Confidence Interval Achieved: 0.7098%
Mean MSE: 0.3530
95% Confidence Interval: (0.3450, 0.3610)
Margin of Error: 0.0080


# Change Dictionary

In [60]:
# Sensor selection for this run: cameras and radar enabled, GPS disabled.
onoffdict = {
    'GPS': False,
    'CAMERAS': True,
    'RADAR': True,
}
# Checkpoint directory for this compression ratio and sensor combination.
weight_path = f'models/CSINettask3/cr{reduction}/gps{onoffdict["GPS"]}_cam{onoffdict["CAMERAS"]}_rad{onoffdict["RADAR"]}/'

In [61]:
# Make sure the checkpoint directory for this configuration exists.
# exist_ok=True replaces the exists()/makedirs() pair, which could raise
# FileExistsError if the directory appeared between the two calls.
os.makedirs(weight_path, exist_ok=True)

In [62]:
# Directory of the pretrained task-1 decoder for the current sensor
# combination; Decoderwithmsi.__init__ loads "task1Decoder.pth" from it.
task1_weight_path=f'models/task1/gps{onoffdict["GPS"]}_cam{onoffdict["CAMERAS"]}_rad{onoffdict["RADAR"]}/'

In [63]:
class Decoderwithmsi(nn.Module):
    """Decoder that fuses the baseline CSI reconstruction with multi-sensor features.

    Pretrained parts are loaded from disk in ``__init__``: the task-1 decoder
    provides the sensor branches (``gp``, ``rd``, ``lc``, ``cc``, ``rc``) and
    the fusion layer ``linear``; ``bde`` is the task-3 baseline decoder. The
    only layers created here are ``output_fc1`` and ``output_fc2``.

    Depends on the notebook globals ``task1_weight_path``, ``weight_pathb``,
    ``num_H`` and ``process_imgs``.
    """

    def __init__(self, reduction):
        """Load the pretrained sub-networks and build the output head.

        Args:
            reduction: Not used by this constructor; kept so existing callers
                (``Decoderwithmsi(reduction)``) keep working.
        """
        super().__init__()
        # NOTE(security): torch.load unpickles arbitrary objects -- only point
        # these paths at trusted checkpoints.
        self.task1decoder = torch.load(task1_weight_path+"task1Decoder.pth")
        self.gp = self.task1decoder.gp
        self.rd = self.task1decoder.rd
        self.lc = self.task1decoder.lc
        self.cc = self.task1decoder.cc
        self.rc = self.task1decoder.rc
        self.bde = torch.load(weight_pathb+"task3Decoder.pth")
        # While False, forward() runs every pretrained part under no_grad.
        self.allow_update = False

        self.linear = self.task1decoder.linear
        # Width reserved for the sensor features in output_fc1's input. The
        # former if/else branches differed only in this number.
        # Presumably equals the output width of task1decoder.linear -- confirm.
        msi_width = max(int(num_H/2)*int(num_H/2), 32)
        self.output_fc1 = nn.Linear(msi_width+num_H*num_H, 2*num_H*num_H)
        self.output_fc2 = nn.Linear(2*num_H*num_H, num_H*num_H)
        self.output_relu = nn.ReLU()

        # Initialise ONLY the freshly created layers. Walking self.modules()
        # here would also re-initialise the pretrained task1decoder / bde
        # weights that were loaded a few lines above.
        for fresh_layer in (self.output_fc1, self.output_fc2):
            nn.init.xavier_uniform_(fresh_layer.weight)

    def forward(self, Hencoded,gps,radar,left_cam,center_cam,right_cam,onoffdict, input_autoregressive_features):
        """Decode ``Hencoded`` and refine the result with the enabled sensors.

        Args:
            Hencoded: Encoder output; its first dimension is the batch size.
            gps, radar, left_cam, center_cam, right_cam: Sensor inputs. Each is
                only used when the matching ``onoffdict`` entry is truthy
                (``gps`` is additionally passed to ``process_imgs`` when
                cameras are enabled).
            onoffdict: Mapping with the keys ``'GPS'``, ``'RADAR'`` and
                ``'CAMERAS'``.
            input_autoregressive_features: Forwarded unchanged to ``self.bde``.

        Returns:
            Tuple ``(output, autoregressive_features)``: ``output`` has shape
            ``(bs, 1, num_H, num_H)`` and ``autoregressive_features`` is the
            same data before that reshape.
        """
        bs = Hencoded.size(0)
        # Placeholders follow the input's device instead of a notebook global.
        target_device = Hencoded.device

        def disabled_branch():
            # A switched-off sensor contributes an all-zero 16-wide feature.
            return torch.zeros(bs, 16, device=target_device)

        # Gradients reach the pretrained parts only when allow_update is set
        # AND the caller has not disabled autograd itself (e.g. during eval).
        track_grad = bool(self.allow_update) and torch.is_grad_enabled()
        with torch.set_grad_enabled(track_grad):
            Hdecoded, _ = self.bde(Hencoded, input_autoregressive_features)

            gps_out = self.gp(gps) if onoffdict['GPS'] else disabled_branch()
            radar_out = self.rd(radar) if onoffdict['RADAR'] else disabled_branch()

            if onoffdict['CAMERAS']:
                left_cam, center_cam, right_cam = process_imgs(gps, left_cam, center_cam, right_cam, crop_size = (150,150))
                lc_out = self.lc(left_cam)
                cc_out = self.cc(center_cam)
                rc_out = self.rc(right_cam)
            else:
                lc_out = disabled_branch()
                cc_out = disabled_branch()
                rc_out = disabled_branch()

            combined1 = torch.cat((gps_out, radar_out, lc_out, cc_out, rc_out), dim=1)
            output = self.linear(combined1)

        # Output head: always runs in the caller's autograd mode.
        output = self.output_relu(output)
        combined2 = torch.cat((output, Hdecoded.view(bs,-1)), dim=1)
        output = self.output_fc1(combined2)
        output = self.output_relu(output)
        output = self.output_fc2(output)

        autoregressive_features = output
        output = output.view(bs, 1, num_H, num_H)

        return output, autoregressive_features

In [64]:
# Full task-3 pipeline: pretrained encoder followed by the sensor-aided decoder.
class task3model(nn.Module):
    """End-to-end task-3 model with a two-step autoregressive decoder memory.

    The encoder output is handed to the decoder unchanged (nothing is applied
    in between). After every call the decoder's feature tensor is stored in
    ``self.ar`` so that the next time steps can feed it back in.
    """

    def __init__(self, reduction=16):
        """Load the pretrained encoder and build the decoder.

        Reads the notebook global ``weight_pathb`` for the encoder checkpoint.
        """
        super().__init__()
        self.total_size = 8192
        self.reduced_size = self.total_size // reduction
        self.en = torch.load(weight_pathb + "task3Encoder.pth")
        self.de = Decoderwithmsi(reduction)
        # One slot per time step for the decoder's autoregressive features.
        self.ar = [None] * 5

    def forward(self, Hin, gps, radar, left_cam, center_cam, right_cam, time_index, device, is_training, onoffdict): 
        """Encode ``Hin``, decode it with sensor data and the stored history.

        ``time_index`` selects how much history exists: none at step 0 (all
        zeros), one step at step 1 (second entry zero-filled), otherwise the
        two preceding steps, most recent first. ``is_training`` is accepted
        but not used. Returns the decoder's reconstruction.
        """
        n_samples = Hin.shape[0]
        ar_width = self.total_size // 2
        code = self.en(Hin)

        def empty(slots):
            # Zero history for time steps that have not happened yet.
            return torch.zeros((n_samples, slots, ar_width), dtype=torch.float).to(device)

        def recall(step):
            # Stored features of an earlier step, cut off from the autograd graph.
            return self.ar[step].view(n_samples, 1, ar_width).detach()

        if time_index == 0:
            slot = 0
            iarf = empty(2)
        elif time_index == 1:
            slot = 1
            iarf = torch.cat([recall(0), empty(1)], dim=1)
        else:
            slot = time_index
            iarf = torch.cat([recall(time_index - 1), recall(time_index - 2)], dim=1)

        Hdecoded, self.ar[slot] = self.de(code, gps, radar, left_cam, center_cam, right_cam, onoffdict, iarf)
        return Hdecoded


In [65]:
# Build a task3model from the pretrained encoder/decoder checkpoints. The
# inference cell that follows rebinds `model` to the checkpoint loaded from
# weight_path + "task3.pth", so this freshly built instance is not the one
# that gets evaluated.
model=task3model(reduction)

# Inference

In [66]:

# Evaluate the trained checkpoint for the current sensor configuration:
# repeat the test pass `num_runs` times, then report the mean MSE and the mean
# percentage improvement over `target_loss`, each with its margin of error.
avg_mse_list = []
improvement_list = []
# The checkpoint, the loader and the reference loss are the same for every
# run, so they are prepared once instead of re-reading the model from disk
# (and moving it to the device) on each iteration. With shuffle=True the
# loader still draws a new sample order every time it is iterated.
model = torch.load(weight_path + "task3.pth").to(device)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
target_mse = target_loss.item()
for _ in range(num_runs):
    avg_mse = run_test(model, test_loader, device, criterion)
    avg_mse_list.append(avg_mse)
    # Relative gain of this run over the reference loss, in percent.
    improvement = (target_mse - avg_mse) / target_mse * 100
    improvement_list.append(improvement)
mean_mse, margin_of_error = calculate_confidence_interval(avg_mse_list)
improvement_mean, improvement_margin_of_error = calculate_confidence_interval(improvement_list)
print(f'Percentage Improvement Mean Achieved: {improvement_mean:.4f}%')
print(f'Percentage Improvement Confidence Interval Achieved: {improvement_margin_of_error:.4f}%')
print(f"Mean MSE: {mean_mse:.4f}")
print(f"95% Confidence Interval: ({mean_mse - margin_of_error:.4f}, {mean_mse + margin_of_error:.4f})")
print(f"Margin of Error: {margin_of_error:.4f}")

Percentage Improvement Mean Achieved: 68.2396%
Percentage Improvement Confidence Interval Achieved: 0.4952%
Mean MSE: 0.3585
95% Confidence Interval: (0.3529, 0.3641)
Margin of Error: 0.0056


# Change Dictionary

In [67]:
# Sensor selection for this run: radar only; GPS and cameras disabled.
onoffdict = {
    'GPS': False,
    'CAMERAS': False,
    'RADAR': True,
}
# Checkpoint directory for this compression ratio and sensor combination.
weight_path = f'models/CSINettask3/cr{reduction}/gps{onoffdict["GPS"]}_cam{onoffdict["CAMERAS"]}_rad{onoffdict["RADAR"]}/'

In [68]:
# Make sure the checkpoint directory for this configuration exists.
# exist_ok=True replaces the exists()/makedirs() pair, which could raise
# FileExistsError if the directory appeared between the two calls.
os.makedirs(weight_path, exist_ok=True)

In [69]:
# Directory of the pretrained task-1 decoder for the current sensor
# combination; Decoderwithmsi.__init__ loads "task1Decoder.pth" from it.
task1_weight_path=f'models/task1/gps{onoffdict["GPS"]}_cam{onoffdict["CAMERAS"]}_rad{onoffdict["RADAR"]}/'

In [70]:
class Decoderwithmsi(nn.Module):
    """Decoder that fuses the baseline CSI reconstruction with multi-sensor features.

    Pretrained parts are loaded from disk in ``__init__``: the task-1 decoder
    provides the sensor branches (``gp``, ``rd``, ``lc``, ``cc``, ``rc``) and
    the fusion layer ``linear``; ``bde`` is the task-3 baseline decoder. The
    only layers created here are ``output_fc1`` and ``output_fc2``.

    Depends on the notebook globals ``task1_weight_path``, ``weight_pathb``,
    ``num_H`` and ``process_imgs``.
    """

    def __init__(self, reduction):
        """Load the pretrained sub-networks and build the output head.

        Args:
            reduction: Not used by this constructor; kept so existing callers
                (``Decoderwithmsi(reduction)``) keep working.
        """
        super().__init__()
        # NOTE(security): torch.load unpickles arbitrary objects -- only point
        # these paths at trusted checkpoints.
        self.task1decoder = torch.load(task1_weight_path+"task1Decoder.pth")
        self.gp = self.task1decoder.gp
        self.rd = self.task1decoder.rd
        self.lc = self.task1decoder.lc
        self.cc = self.task1decoder.cc
        self.rc = self.task1decoder.rc
        self.bde = torch.load(weight_pathb+"task3Decoder.pth")
        # While False, forward() runs every pretrained part under no_grad.
        self.allow_update = False

        self.linear = self.task1decoder.linear
        # Width reserved for the sensor features in output_fc1's input. The
        # former if/else branches differed only in this number.
        # Presumably equals the output width of task1decoder.linear -- confirm.
        msi_width = max(int(num_H/2)*int(num_H/2), 32)
        self.output_fc1 = nn.Linear(msi_width+num_H*num_H, 2*num_H*num_H)
        self.output_fc2 = nn.Linear(2*num_H*num_H, num_H*num_H)
        self.output_relu = nn.ReLU()

        # Initialise ONLY the freshly created layers. Walking self.modules()
        # here would also re-initialise the pretrained task1decoder / bde
        # weights that were loaded a few lines above.
        for fresh_layer in (self.output_fc1, self.output_fc2):
            nn.init.xavier_uniform_(fresh_layer.weight)

    def forward(self, Hencoded,gps,radar,left_cam,center_cam,right_cam,onoffdict, input_autoregressive_features):
        """Decode ``Hencoded`` and refine the result with the enabled sensors.

        Args:
            Hencoded: Encoder output; its first dimension is the batch size.
            gps, radar, left_cam, center_cam, right_cam: Sensor inputs. Each is
                only used when the matching ``onoffdict`` entry is truthy
                (``gps`` is additionally passed to ``process_imgs`` when
                cameras are enabled).
            onoffdict: Mapping with the keys ``'GPS'``, ``'RADAR'`` and
                ``'CAMERAS'``.
            input_autoregressive_features: Forwarded unchanged to ``self.bde``.

        Returns:
            Tuple ``(output, autoregressive_features)``: ``output`` has shape
            ``(bs, 1, num_H, num_H)`` and ``autoregressive_features`` is the
            same data before that reshape.
        """
        bs = Hencoded.size(0)
        # Placeholders follow the input's device instead of a notebook global.
        target_device = Hencoded.device

        def disabled_branch():
            # A switched-off sensor contributes an all-zero 16-wide feature.
            return torch.zeros(bs, 16, device=target_device)

        # Gradients reach the pretrained parts only when allow_update is set
        # AND the caller has not disabled autograd itself (e.g. during eval).
        track_grad = bool(self.allow_update) and torch.is_grad_enabled()
        with torch.set_grad_enabled(track_grad):
            Hdecoded, _ = self.bde(Hencoded, input_autoregressive_features)

            gps_out = self.gp(gps) if onoffdict['GPS'] else disabled_branch()
            radar_out = self.rd(radar) if onoffdict['RADAR'] else disabled_branch()

            if onoffdict['CAMERAS']:
                left_cam, center_cam, right_cam = process_imgs(gps, left_cam, center_cam, right_cam, crop_size = (150,150))
                lc_out = self.lc(left_cam)
                cc_out = self.cc(center_cam)
                rc_out = self.rc(right_cam)
            else:
                lc_out = disabled_branch()
                cc_out = disabled_branch()
                rc_out = disabled_branch()

            combined1 = torch.cat((gps_out, radar_out, lc_out, cc_out, rc_out), dim=1)
            output = self.linear(combined1)

        # Output head: always runs in the caller's autograd mode.
        output = self.output_relu(output)
        combined2 = torch.cat((output, Hdecoded.view(bs,-1)), dim=1)
        output = self.output_fc1(combined2)
        output = self.output_relu(output)
        output = self.output_fc2(output)

        autoregressive_features = output
        output = output.view(bs, 1, num_H, num_H)

        return output, autoregressive_features

In [71]:
# Full task-3 pipeline: pretrained encoder followed by the sensor-aided decoder.
class task3model(nn.Module):
    """End-to-end task-3 model with a two-step autoregressive decoder memory.

    The encoder output is handed to the decoder unchanged (nothing is applied
    in between). After every call the decoder's feature tensor is stored in
    ``self.ar`` so that the next time steps can feed it back in.
    """

    def __init__(self, reduction=16):
        """Load the pretrained encoder and build the decoder.

        Reads the notebook global ``weight_pathb`` for the encoder checkpoint.
        """
        super().__init__()
        self.total_size = 8192
        self.reduced_size = self.total_size // reduction
        self.en = torch.load(weight_pathb + "task3Encoder.pth")
        self.de = Decoderwithmsi(reduction)
        # One slot per time step for the decoder's autoregressive features.
        self.ar = [None] * 5

    def forward(self, Hin, gps, radar, left_cam, center_cam, right_cam, time_index, device, is_training, onoffdict): 
        """Encode ``Hin``, decode it with sensor data and the stored history.

        ``time_index`` selects how much history exists: none at step 0 (all
        zeros), one step at step 1 (second entry zero-filled), otherwise the
        two preceding steps, most recent first. ``is_training`` is accepted
        but not used. Returns the decoder's reconstruction.
        """
        n_samples = Hin.shape[0]
        ar_width = self.total_size // 2
        code = self.en(Hin)

        def empty(slots):
            # Zero history for time steps that have not happened yet.
            return torch.zeros((n_samples, slots, ar_width), dtype=torch.float).to(device)

        def recall(step):
            # Stored features of an earlier step, cut off from the autograd graph.
            return self.ar[step].view(n_samples, 1, ar_width).detach()

        if time_index == 0:
            slot = 0
            iarf = empty(2)
        elif time_index == 1:
            slot = 1
            iarf = torch.cat([recall(0), empty(1)], dim=1)
        else:
            slot = time_index
            iarf = torch.cat([recall(time_index - 1), recall(time_index - 2)], dim=1)

        Hdecoded, self.ar[slot] = self.de(code, gps, radar, left_cam, center_cam, right_cam, onoffdict, iarf)
        return Hdecoded


In [72]:
# Build a task3model from the pretrained encoder/decoder checkpoints. The
# inference cell that follows rebinds `model` to the checkpoint loaded from
# weight_path + "task3.pth", so this freshly built instance is not the one
# that gets evaluated.
model=task3model(reduction)

# Inference

In [73]:

# Evaluate the trained checkpoint for the current sensor configuration:
# repeat the test pass `num_runs` times, then report the mean MSE and the mean
# percentage improvement over `target_loss`, each with its margin of error.
avg_mse_list = []
improvement_list = []
# The checkpoint, the loader and the reference loss are the same for every
# run, so they are prepared once instead of re-reading the model from disk
# (and moving it to the device) on each iteration. With shuffle=True the
# loader still draws a new sample order every time it is iterated.
model = torch.load(weight_path + "task3.pth").to(device)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
target_mse = target_loss.item()
for _ in range(num_runs):
    avg_mse = run_test(model, test_loader, device, criterion)
    avg_mse_list.append(avg_mse)
    # Relative gain of this run over the reference loss, in percent.
    improvement = (target_mse - avg_mse) / target_mse * 100
    improvement_list.append(improvement)
mean_mse, margin_of_error = calculate_confidence_interval(avg_mse_list)
improvement_mean, improvement_margin_of_error = calculate_confidence_interval(improvement_list)
print(f'Percentage Improvement Mean Achieved: {improvement_mean:.4f}%')
print(f'Percentage Improvement Confidence Interval Achieved: {improvement_margin_of_error:.4f}%')
print(f"Mean MSE: {mean_mse:.4f}")
print(f"95% Confidence Interval: ({mean_mse - margin_of_error:.4f}, {mean_mse + margin_of_error:.4f})")
print(f"Margin of Error: {margin_of_error:.4f}")

Percentage Improvement Mean Achieved: 68.9248%
Percentage Improvement Confidence Interval Achieved: 0.6763%
Mean MSE: 0.3508
95% Confidence Interval: (0.3431, 0.3584)
Margin of Error: 0.0076


# Change Dictionary

In [74]:
# Sensor selection for this run: cameras only; GPS and radar disabled.
onoffdict = {
    'GPS': False,
    'CAMERAS': True,
    'RADAR': False,
}
# Checkpoint directory for this compression ratio and sensor combination.
weight_path = f'models/CSINettask3/cr{reduction}/gps{onoffdict["GPS"]}_cam{onoffdict["CAMERAS"]}_rad{onoffdict["RADAR"]}/'

In [75]:
# Make sure the checkpoint directory for this configuration exists.
# exist_ok=True replaces the exists()/makedirs() pair, which could raise
# FileExistsError if the directory appeared between the two calls.
os.makedirs(weight_path, exist_ok=True)

In [76]:
# Directory of the pretrained task-1 decoder for the current sensor
# combination; Decoderwithmsi.__init__ loads "task1Decoder.pth" from it.
task1_weight_path=f'models/task1/gps{onoffdict["GPS"]}_cam{onoffdict["CAMERAS"]}_rad{onoffdict["RADAR"]}/'

In [77]:
class Decoderwithmsi(nn.Module):
    """Decoder that fuses the baseline CSI reconstruction with multi-sensor features.

    Pretrained parts are loaded from disk in ``__init__``: the task-1 decoder
    provides the sensor branches (``gp``, ``rd``, ``lc``, ``cc``, ``rc``) and
    the fusion layer ``linear``; ``bde`` is the task-3 baseline decoder. The
    only layers created here are ``output_fc1`` and ``output_fc2``.

    Depends on the notebook globals ``task1_weight_path``, ``weight_pathb``,
    ``num_H`` and ``process_imgs``.
    """

    def __init__(self, reduction):
        """Load the pretrained sub-networks and build the output head.

        Args:
            reduction: Not used by this constructor; kept so existing callers
                (``Decoderwithmsi(reduction)``) keep working.
        """
        super().__init__()
        # NOTE(security): torch.load unpickles arbitrary objects -- only point
        # these paths at trusted checkpoints.
        self.task1decoder = torch.load(task1_weight_path+"task1Decoder.pth")
        self.gp = self.task1decoder.gp
        self.rd = self.task1decoder.rd
        self.lc = self.task1decoder.lc
        self.cc = self.task1decoder.cc
        self.rc = self.task1decoder.rc
        self.bde = torch.load(weight_pathb+"task3Decoder.pth")
        # While False, forward() runs every pretrained part under no_grad.
        self.allow_update = False

        self.linear = self.task1decoder.linear
        # Width reserved for the sensor features in output_fc1's input. The
        # former if/else branches differed only in this number.
        # Presumably equals the output width of task1decoder.linear -- confirm.
        msi_width = max(int(num_H/2)*int(num_H/2), 32)
        self.output_fc1 = nn.Linear(msi_width+num_H*num_H, 2*num_H*num_H)
        self.output_fc2 = nn.Linear(2*num_H*num_H, num_H*num_H)
        self.output_relu = nn.ReLU()

        # Initialise ONLY the freshly created layers. Walking self.modules()
        # here would also re-initialise the pretrained task1decoder / bde
        # weights that were loaded a few lines above.
        for fresh_layer in (self.output_fc1, self.output_fc2):
            nn.init.xavier_uniform_(fresh_layer.weight)

    def forward(self, Hencoded,gps,radar,left_cam,center_cam,right_cam,onoffdict, input_autoregressive_features):
        """Decode ``Hencoded`` and refine the result with the enabled sensors.

        Args:
            Hencoded: Encoder output; its first dimension is the batch size.
            gps, radar, left_cam, center_cam, right_cam: Sensor inputs. Each is
                only used when the matching ``onoffdict`` entry is truthy
                (``gps`` is additionally passed to ``process_imgs`` when
                cameras are enabled).
            onoffdict: Mapping with the keys ``'GPS'``, ``'RADAR'`` and
                ``'CAMERAS'``.
            input_autoregressive_features: Forwarded unchanged to ``self.bde``.

        Returns:
            Tuple ``(output, autoregressive_features)``: ``output`` has shape
            ``(bs, 1, num_H, num_H)`` and ``autoregressive_features`` is the
            same data before that reshape.
        """
        bs = Hencoded.size(0)
        # Placeholders follow the input's device instead of a notebook global.
        target_device = Hencoded.device

        def disabled_branch():
            # A switched-off sensor contributes an all-zero 16-wide feature.
            return torch.zeros(bs, 16, device=target_device)

        # Gradients reach the pretrained parts only when allow_update is set
        # AND the caller has not disabled autograd itself (e.g. during eval).
        track_grad = bool(self.allow_update) and torch.is_grad_enabled()
        with torch.set_grad_enabled(track_grad):
            Hdecoded, _ = self.bde(Hencoded, input_autoregressive_features)

            gps_out = self.gp(gps) if onoffdict['GPS'] else disabled_branch()
            radar_out = self.rd(radar) if onoffdict['RADAR'] else disabled_branch()

            if onoffdict['CAMERAS']:
                left_cam, center_cam, right_cam = process_imgs(gps, left_cam, center_cam, right_cam, crop_size = (150,150))
                lc_out = self.lc(left_cam)
                cc_out = self.cc(center_cam)
                rc_out = self.rc(right_cam)
            else:
                lc_out = disabled_branch()
                cc_out = disabled_branch()
                rc_out = disabled_branch()

            combined1 = torch.cat((gps_out, radar_out, lc_out, cc_out, rc_out), dim=1)
            output = self.linear(combined1)

        # Output head: always runs in the caller's autograd mode.
        output = self.output_relu(output)
        combined2 = torch.cat((output, Hdecoded.view(bs,-1)), dim=1)
        output = self.output_fc1(combined2)
        output = self.output_relu(output)
        output = self.output_fc2(output)

        autoregressive_features = output
        output = output.view(bs, 1, num_H, num_H)

        return output, autoregressive_features

In [78]:
#complete task 3 model including encoder, decoder and channel
class task3model(nn.Module):
    """End-to-end task 3 model: pretrained encoder followed by the multi-modal decoder.

    The encoder output is handed to the decoder unchanged (no channel
    distortion is applied between the two). The decoder is autoregressive
    over time steps: for step ``t`` it additionally receives the features it
    produced at steps ``t-1`` and ``t-2`` (zeros where no such step exists).
    Those features are cached on the instance in ``self.ar``, indexed by
    ``time_index``, so the steps of one sequence must be evaluated in order
    starting from ``time_index == 0``.
    """

    def __init__(self, reduction=16):
        """Load the saved encoder and build a fresh decoder.

        Args:
            reduction: compression factor; ``reduced_size`` is
                ``total_size // reduction`` and the value is forwarded to the
                decoder constructor.
        """
        super().__init__()
        self.total_size = 8192
        self.reduced_size = self.total_size//reduction
        # The encoder is restored as a whole pickled object, so the file must be
        # trusted and the encoder class must be defined before this runs.
        # NOTE(review): recent PyTorch releases default torch.load to
        # weights_only=True, which rejects pickled modules — confirm against the
        # installed version.
        self.en = torch.load(weight_pathb+"task3Encoder.pth")
        self.de = Decoderwithmsi(reduction)
        # Autoregressive features per time step; grown on demand in forward().
        self.ar = [None] * 5

    def _past_features(self, step, batch_size, feat_len, device):
        """Return the detached features of `step` as (batch, 1, feat_len), or zeros if step < 0."""
        if step < 0:
            return torch.zeros((batch_size, 1, feat_len), dtype=torch.float).to(device)
        past = self.ar[step] if step < len(self.ar) else None
        if past is None:
            raise RuntimeError(
                f"No autoregressive features stored for time step {step}; "
                "evaluate the earlier time steps of the sequence first."
            )
        # Detach so gradients do not flow back into earlier time steps.
        return past.view(batch_size, 1, feat_len).detach()

    def forward(self, Hin, gps, radar, left_cam, center_cam, right_cam, time_index, device, is_training, onoffdict):
        """Encode `Hin` and decode it together with side information and past features.

        Args:
            Hin: encoder input; its first dimension is used as the batch size.
            gps, radar, left_cam, center_cam, right_cam: side-information
                inputs, passed through to the decoder untouched.
            time_index: non-negative index of the current step in the sequence.
            device: device on which the zero-filled history is created.
            is_training: accepted for interface compatibility; not used here.
            onoffdict: modality switches, passed through to the decoder.

        Returns:
            The decoder's first output for this time step.

        Raises:
            ValueError: if `time_index` is negative.
            RuntimeError: if a required earlier time step has not been run.
        """
        time_index = int(time_index)
        if time_index < 0:
            raise ValueError(f"time_index must be non-negative, got {time_index}")

        batch_size = Hin.shape[0]
        feat_len = self.total_size//2
        Hencoded = self.en(Hin)

        # Decoder input history: features of step t-1 first, then step t-2.
        history = [
            self._past_features(time_index - lag, batch_size, feat_len, device)
            for lag in (1, 2)
        ]
        iarf = torch.cat(history, dim=1)

        Hdecoded, features = self.de(Hencoded, gps, radar, left_cam, center_cam, right_cam, onoffdict, iarf)

        # Grow the cache so sequences longer than the initial 5 slots work too.
        if time_index >= len(self.ar):
            self.ar.extend([None] * (time_index + 1 - len(self.ar)))
        self.ar[time_index] = features

        return Hdecoded


In [79]:
# Build the full task 3 model (saved encoder + new decoder) for the configured `reduction`.
model = task3model(reduction=reduction)

# Inference

In [80]:

# Evaluate the saved checkpoint `num_runs` times and collect per-run statistics.
avg_mse_list = []
improvement_list = []
for _ in range(num_runs):
    # Every repetition gets its own shuffled loader and a freshly loaded checkpoint.
    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=4,
    )
    model = torch.load(weight_path + "task3.pth")
    model = model.to(device)

    avg_mse = run_test(model, test_loader, device, criterion)
    avg_mse_list.append(avg_mse)

    # Relative gain over the reference loss, expressed in percent.
    target_loss_value = target_loss.item()
    improvement = (target_loss_value - avg_mse) / target_loss_value * 100
    improvement_list.append(improvement)

# Summarise both series as (mean, margin of error).
mean_mse, margin_of_error = calculate_confidence_interval(avg_mse_list)
improvement_mean, improvement_margin_of_error = calculate_confidence_interval(improvement_list)

print(f'Percentage Improvement Mean Achieved: {improvement_mean:.4f}%')
print(f'Percentage Improvement Confidence Interval Achieved: {improvement_margin_of_error:.4f}%')
print(f"Mean MSE: {mean_mse:.4f}")
print(f"95% Confidence Interval: ({mean_mse - margin_of_error:.4f}, {mean_mse + margin_of_error:.4f})")
print(f"Margin of Error: {margin_of_error:.4f}")

Percentage Improvement Mean Achieved: 68.7134%
Percentage Improvement Confidence Interval Achieved: 1.1049%
Mean MSE: 0.3531
95% Confidence Interval: (0.3407, 0.3656)
Margin of Error: 0.0125
