In [1]:
import numpy as np
import pandas as pd
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # Select the GPU index
import scipy.io
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from torch.nn.init import xavier_uniform_
from torch.nn.init import constant_
from torch.nn.init import xavier_normal_
from torch.utils.data import Dataset, DataLoader
import math
from PIL import Image
from collections import OrderedDict
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
import torchvision.transforms as T
import matplotlib.pyplot as plt
import time
from torch.optim.lr_scheduler import _LRScheduler
import warnings
import spacy
from scipy.io import savemat
from scipy import stats
import dill as pickle
import thop
from typing import Optional, Tuple, Any
from typing import List, Optional, Tuple
from torch_challenge_dataset import DeepVerseChallengeLoaderTaskTwo
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=UserWarning)

In [2]:
# Short alias used in the type annotations of the attention / transformer code below.
from torch import Tensor

In [3]:
# Parameters
# Side-information modalities switched on for this run.
onoffdict = {'GPS': True, 'CAMERAS': True, 'RADAR': True}

# Optimisation settings
lr = 1e-3
num_epochs = 100
batch_size = 200

# Model settings
reduction = 16  # compression ratio ("cr" in the checkpoint path)
num_H = 64      # side length of the num_H x num_H output map

# Checkpoint directory, keyed by compression ratio and modality flags.
weight_path = (f'models/TransNettask2/cr{reduction}/'
               f'gps{onoffdict["GPS"]}_cam{onoffdict["CAMERAS"]}_rad{onoffdict["RADAR"]}/')

In [4]:
reduction  # echo the configured compression ratio as this cell's output

16

In [5]:
weight_path  # echo the model-weight directory derived from reduction and the modality flags

'models/TransNettask2/cr16/gpsTrue_camTrue_radTrue/'

In [6]:
# Create the model-weight directory. exist_ok=True makes the cell idempotent on re-run
# and avoids the check-then-create race of os.path.exists() followed by os.makedirs().
os.makedirs(weight_path, exist_ok=True)

In [7]:
# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

In [8]:
# Train / validation splits.
# Only the training loader is shuffled. Evaluation does not need a random order, and a fixed
# order keeps validation batches -- and anything computed per batch, such as gpsdata's
# per-batch min-max scaling -- reproducible from run to run.
NUM_WORKERS = 5
train_dataset = DeepVerseChallengeLoaderTaskTwo(csv_path = r'./dataset_train.csv')
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=NUM_WORKERS)
test_dataset = DeepVerseChallengeLoaderTaskTwo(csv_path = r'./dataset_validation.csv')
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=NUM_WORKERS)

# Utils and Models

In [9]:
def CSI_abs_reshape(y, csi_std=2.8117975e-06, target_std=1.0):
    """Return the element-wise magnitude of ``y`` rescaled so that ``csi_std`` maps to ``target_std``.

    Args:
        y: real or complex tensor.
        csi_std: reference scale of the raw values.
        target_std: value that a magnitude equal to ``csi_std`` is mapped to.
    """
    magnitude = y.abs()
    return magnitude / csi_std * target_std

In [10]:
def CSI_reshape(y, csi_std=2.5e-06, target_std=1):
    """Split complex ``y`` into real and imaginary parts, concatenate them along dim 1 and rescale.

    The result has twice the size of ``y`` along dim 1 (real block first, then imaginary),
    scaled so that a value of ``csi_std`` maps to ``target_std``.
    """
    stacked = torch.cat((y.real, y.imag), dim=1)
    return stacked / csi_std * target_std

In [11]:
def cal_model_parameters(model):
    """Return the total number of scalar entries in ``model.parameters()`` (shared tensors counted once)."""
    return sum(int(param.numel()) for param in model.parameters())

In [12]:
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def normalize_image(image, mean=IMAGENET_MEAN, std=IMAGENET_STD):
    """Scale 0-255 pixel values to [0, 1], then standardise each channel.

    Computes ``(image / 255 - mean[c]) / std[c]`` per channel, with the channel axis at -3
    (e.g. (C, H, W) or (N, C, H, W)). This is the same arithmetic as
    ``torchvision.transforms.Normalize(mean, std)`` applied after the division. The input
    tensor is never modified.

    Args:
        image: tensor of pixel values in [0, 255] (any dtype; converted to float32).
        mean: per-channel means, ImageNet values by default.
        std: per-channel standard deviations, ImageNet values by default.

    Returns:
        A new float32 tensor with the same shape as ``image``.
    """
    # Out-of-place division: for float32 input, .float() returns the caller's own tensor,
    # so an in-place `/=` would silently rescale the caller's data.
    scaled = image.float() / 255.0
    mean_t = torch.as_tensor(mean, dtype=scaled.dtype, device=scaled.device).view(-1, 1, 1)
    std_t = torch.as_tensor(std, dtype=scaled.dtype, device=scaled.device).view(-1, 1, 1)
    return (scaled - mean_t) / std_t

In [13]:
def left_coordinates_batch(x_cor, y_cor):
    """Map GPS (x, y) coordinates to pixel coordinates in the left camera image.

    ``y_cor`` picks one of five bands (boundaries -4, -1, 1, 4). Each band applies its own
    hand-calibrated linear mapping:
    y_pix rises linearly from 100 at x_cor == 80 to a band-specific value at x_cor == 125,
    and x_pix = (y_pix - offset) / scale.

    Args:
        x_cor: tensor of x coordinates.
        y_cor: tensor of y coordinates, same shape as ``x_cor``.

    Returns:
        (x_pix, y_pix). x_pix has ``y_cor``'s dtype and y_pix has ``x_cor``'s.
        Entries that fall in no band (e.g. NaN) stay 0.
    """
    x_pix = torch.zeros_like(y_cor)
    y_pix = torch.zeros_like(x_cor)

    # (band mask, y_pix at x_cor == 125, x_pix offset, x_pix scale)
    bands = (
        (y_cor < -4, 250, 30, 1.35),
        ((y_cor >= -4) & (y_cor < -1), 250, 45, 0.85),
        ((y_cor >= -1) & (y_cor < 1), 250, 55, 0.70),
        ((y_cor >= 1) & (y_cor < 4), 210, 65, 0.60),
        (y_cor >= 4, 190, 65, 0.5),
    )
    # The bands are disjoint, so filling them one at a time gives the same result as
    # computing every y_pix first and every x_pix afterwards.
    for mask, y_far, x_offset, x_scale in bands:
        y_pix[mask] = 100 + (y_far - 100) * ((x_cor[mask] - 80) / (125 - 80))
        x_pix[mask] = (y_pix[mask] - x_offset) / x_scale
    return x_pix, y_pix

In [14]:
def center_coordinates_batch(x_cor, y_cor):
    """Map GPS (x, y) coordinates to pixel coordinates in the center camera image.

    x_pix uses one linear map for negative ``y_cor`` and another for ``y_cor >= 0``.
    y_pix maps ``y_cor`` in [-7, 7] linearly onto [175, 100].

    Returns:
        (x_pix, y_pix). x_pix is cast to ``x_cor``'s dtype.
    """
    x_when_negative_y = 256 * ((x_cor - 119) / (139 - 119))
    # NOTE(review): offset 112 vs span (146 - 113) looks inconsistent -- possibly a typo in the
    # calibration. Kept unchanged so the calibrated output is preserved.
    x_when_nonnegative_y = 256 * ((x_cor - 112) / (146 - 113))
    x_pix = torch.where(y_cor < 0, x_when_negative_y, x_when_nonnegative_y).to(x_cor.dtype)

    y_pix = 175 + (100 - 175) * ((y_cor - (-7)) / ((7) - (-7)))
    return x_pix, y_pix

In [15]:
def right_coordinates_batch(x_cor, y_cor):
    """Map GPS (x, y) coordinates to pixel coordinates in the right camera image.

    ``y_cor`` picks one of five bands (boundaries -4, -1, 1, 4). Each band applies its own
    hand-calibrated linear mapping:
    y_pix falls linearly from a band-specific value at x_cor == 125 to 100 at x_cor == 200,
    and x_pix = -(y_pix - offset) / scale.

    Returns:
        (x_pix, y_pix). x_pix has ``y_cor``'s dtype and y_pix has ``x_cor``'s.
        Entries that fall in no band (e.g. NaN) stay 0.
    """
    x_pix = torch.zeros_like(y_cor)
    y_pix = torch.zeros_like(x_cor)

    # (band mask, y_pix at x_cor == 125, x_pix offset, x_pix scale)
    bands = (
        (y_cor < -4, 250, 370, 1.25),
        ((y_cor >= -4) & (y_cor < -1), 250, 285, 0.87),
        ((y_cor >= -1) & (y_cor < 1), 250, 250, 0.73),
        ((y_cor >= 1) & (y_cor < 4), 210, 210, 0.55),
        (y_cor >= 4, 190, 190, 0.45),
    )
    for mask, y_near, x_offset, x_scale in bands:
        y_pix[mask] = y_near + (100 - y_near) * ((x_cor[mask] - 125) / (200 - 125))
        x_pix[mask] = -(y_pix[mask] - x_offset) / x_scale
    return x_pix, y_pix

In [16]:
def center_image_batch(images, center_x, center_y, output_size, bounded=False):
    """Crop one window of ``output_size`` (height, width) per image around given pixel centres.

    Args:
        images: batch of image tensors, iterated along the first axis.
        center_x, center_y: per-image crop centres (tensors indexable by image position).
        output_size: (height, width) of each crop.
        bounded: 'left' clamps the top-left corner to >= 0; 'right' clamps the right edge
            to <= 250 and the top to >= 0; anything else centres the window with no clamping.

    Returns:
        A tensor stacking the cropped (then resized to ``output_size``) images.
    """
    crop_h, crop_w = output_size[0], output_size[1]

    if bounded == 'left':
        top = torch.clamp(center_y - crop_h // 2, 0, None)
        left = torch.clamp(center_x - crop_w // 2, 0, None)
    elif bounded == 'right':
        # NOTE(review): 250 is a hard-coded right edge (presumably the source image width) -- confirm.
        right = torch.clamp(center_x + crop_w // 2, None, 250)
        top = torch.clamp(center_y + crop_h // 2 - crop_h, 0, None)
        left = right - crop_w
    else:
        top = center_y - crop_h // 2
        left = center_x - crop_w // 2

    resize = transforms.Resize(output_size)
    # Move the offsets to Python once instead of calling .item() per image.
    tops, lefts = top.tolist(), left.tolist()
    crops = [resize(TF.crop(image, int(tops[i]), int(lefts[i]), crop_h, crop_w))
             for i, image in enumerate(images)]
    return torch.stack(crops)

In [17]:
def process_imgs(gps, img_1, img_2, img_3, crop_size = (150,150)):
    """Crop the left, center and right camera batches around each sample's GPS position.

    Args:
        gps: tensor whose columns 0 and 1 are the x and y coordinates.
        img_1, img_2, img_3: left, center and right image batches.
        crop_size: (height, width) of every crop.

    Returns:
        Tuple of the three cropped batches, in the same order as the inputs.
    """
    x_cor, y_cor = gps[:, 0], gps[:, 1]
    views = (
        (img_1, left_coordinates_batch, 'left'),
        (img_2, center_coordinates_batch, False),
        (img_3, right_coordinates_batch, 'right'),
    )
    cropped = []
    for images, to_pixels, bounded in views:
        x_pix, y_pix = to_pixels(x_cor, y_cor)
        cropped.append(center_image_batch(images, x_pix.to(torch.int), y_pix.to(torch.int),
                                          crop_size, bounded))
    return tuple(cropped)

GPS Data Processing Layer

In [18]:
class gpsdata(nn.Module):
    """Embed GPS positions into 16 features with Linear(2, 16) + ReLU.

    Columns 0 and 1 of the input are min-max scaled to [0, 1] using the statistics of the
    *current batch*. The embedding of a sample therefore depends on the other samples in
    its batch.
    """

    # Lower bound on the min-max denominator. Without it, a batch with a single sample, or
    # one where a coordinate is constant, computes 0 / 0 = NaN. It lives on the class (not
    # the instance) so modules unpickled from older checkpoints still find it.
    MINMAX_EPS = 1e-12

    def __init__(self):
        super().__init__()
        self.gps_fc = nn.Linear(2, 16)
        self.gps_relu = nn.ReLU()

    def forward(self, gps):
        """Args:
            gps: 2-D tensor whose first two columns are the x and y coordinates.

        Returns:
            Tensor of shape (gps.size(0), 16) after ReLU.
        """
        coords = gps.to(torch.float32)[:, :2]
        lo = coords.amin(dim=0)
        span = (coords.amax(dim=0) - lo).clamp_min(self.MINMAX_EPS)
        gps_normd = (coords - lo) / span

        gps_out = self.gps_fc(gps_normd)
        gps_out = self.gps_relu(gps_out)
        return gps_out

Radar Data Processing Layer

In [19]:
class radardata(nn.Module):
    """Encode radar frames into 16 features each.

    Pipeline: reshape to (N, 1, 512, 128) -> fixed min-max scaling -> two
    [conv(stride 2) -> dropout -> 2x2 average pool] stages -> flatten -> Linear(256, 16)
    -> dropout -> LeakyReLU(0.3).
    """

    def __init__(self):
        super(radardata, self).__init__()
        # Construction order is unchanged so parameter order and seeded initialisation match.
        self.dropout = nn.Dropout(p=0.5)
        self.conv1 = nn.Conv2d(1, 6, kernel_size=3, stride=2, padding=1)
        self.pool1 = nn.AvgPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(6, 1, kernel_size=3, stride=2, padding=1)
        self.pool2 = nn.AvgPool2d(kernel_size=2, stride=2)
        self.lr1 = nn.LeakyReLU(negative_slope=0.3, inplace=True)
        # 512x128 -> 256x64 -> 128x32 -> 64x16 -> 32x8 = 256 features (one channel)
        self.encoder_fc = nn.Linear(256, 16)

    def forward(self, x):
        frames = x.view(-1, 1, 512, 128)
        # Fixed scaling constants (presumably the radar data's min / max -- TODO confirm).
        lo, hi = 5.1838e-06, 28.0494
        frames = (frames - lo) / (hi - lo)

        stage1 = self.pool1(self.dropout(self.conv1(frames)))
        stage2 = self.pool2(self.dropout(self.conv2(stage1)))
        flat = stage2.view(frames.size(0), -1)
        return self.lr1(self.dropout(self.encoder_fc(flat)))

Camera Data Processing Layers

In [20]:
class cameradata(nn.Module):
    """Encode a batch of RGB camera crops into 16 features per image.

    Pipeline: normalize_image -> two [conv(stride 2) -> dropout -> 2x2 average pool] stages
    -> LeakyReLU(0.3) -> flatten (81 values) -> Linear(81, 16) -> dropout.
    """

    def __init__(self):
        super().__init__()
        # Construction order is unchanged so parameter order and seeded initialisation match.
        self.conv1 = nn.Conv2d(3, 6, kernel_size=3, stride=2, padding=1)
        self.pool1 = nn.AvgPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(6, 1, kernel_size=3, stride=2, padding=1)
        self.pool2 = nn.AvgPool2d(kernel_size=2, stride=2)
        self.lr1 = nn.LeakyReLU(negative_slope=0.3, inplace=True)
        # 81 = 9x9: a 150x150 crop -> 75 -> 37 -> 19 -> 9 after the two conv/pool stages
        self.encoder = nn.Linear(1 * 81, 16)
        self.dropout = nn.Dropout(0.5)

    def forward(self, cam):
        images = normalize_image(cam).to(torch.float32)
        stage1 = self.pool1(self.dropout(self.conv1(images)))
        stage2 = self.pool2(self.dropout(self.conv2(stage1)))
        flat = self.lr1(stage2).view(-1, 1 * 81)
        return self.dropout(self.encoder(flat))

In [21]:
def scale_dot_attention(
        q: Tensor,
        k: Tensor,
        v: Tensor,
        dropout_p: float = 0.0,
        attn_mask: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor]:
    """Batched scaled dot-product attention.

    Args:
        q: (B, L, E) queries. Must be 3-D.
        k: (B, S, E) keys.
        v: (B, S, Ev) values.
        dropout_p: dropout on the attention weights. ``F.dropout`` is called with its default
            ``training=True``, so callers must pass 0 at inference time.
        attn_mask: optional additive mask broadcastable to (B, L, S).

    Returns:
        (output of shape (B, L, Ev), attention weights of shape (B, L, S)).
    """
    _, _, head_dim = q.shape
    scores = torch.bmm(q / math.sqrt(head_dim), k.transpose(-2, -1))
    if attn_mask is not None:
        scores = scores + attn_mask
    weights = F.softmax(scores, dim=-1)
    if dropout_p:
        weights = F.dropout(weights, p=dropout_p)
    return torch.bmm(weights, v), weights


In [22]:
def multi_head_attention_forward(
        query: Tensor,
        key: Tensor,
        value: Tensor,
        num_heads: int,
        in_proj_weight: Tensor,
        in_proj_bias: Optional[Tensor],
        dropout_p: float,
        out_proj_weight: Tensor,
        out_proj_bias: Optional[Tensor],
        training: bool = True,
        key_padding_mask: Optional[Tensor] = None,
        need_weights: bool = True,
        attn_mask: Optional[Tensor] = None,
        use_separate_proj_weight=None,
        q_proj_weight: Optional[Tensor] = None,
        k_proj_weight: Optional[Tensor] = None,
        v_proj_weight: Optional[Tensor] = None,
) -> Tuple[Tensor, Optional[Tensor]]:
    """Functional multi-head attention on sequence-first tensors.

    Args:
        query: (tgt_len, bsz, embed_dim).
        key, value: (src_len, bsz, *). Their width is embed_dim, or kdim / vdim when
            ``use_separate_proj_weight`` is set.
        num_heads: number of heads. Each head has width embed_dim // num_heads.
        in_proj_weight, in_proj_bias: packed (3*embed_dim, embed_dim) projection and its
            (3*embed_dim,) bias. The weight is not used when ``use_separate_proj_weight`` is true.
        dropout_p: dropout on the attention weights. Forced to 0 when ``training`` is False.
        out_proj_weight, out_proj_bias: final output projection.
        key_padding_mask: (bsz, src_len) bool/uint8 mask; True marks keys to ignore.
        need_weights: also return the head-averaged attention weights.
        attn_mask: 2-D (tgt_len, src_len) or 3-D (bsz*num_heads, tgt_len, src_len) mask.
            Either bool/uint8 (True = blocked) or float (added to the scores).
        use_separate_proj_weight: project q, k, v with q/k/v_proj_weight instead of the
            packed in_proj_weight.

    Returns:
        (attn_output of shape (tgt_len, bsz, embed_dim),
         attention weights of shape (bsz, tgt_len, src_len) averaged over heads, or None).
    """
    tgt_len, bsz, embed_dim = query.shape
    src_len, _, _ = key.shape
    head_dim = embed_dim // num_heads

    # Input projections: three separate matrices (kdim/vdim != embed_dim) or one packed matrix.
    if use_separate_proj_weight:
        assert q_proj_weight is not None and k_proj_weight is not None and v_proj_weight is not None, \
            "use_separate_proj_weight is True but q_proj_weight, k_proj_weight or v_proj_weight is None"
        if in_proj_bias is None:
            b_q = b_k = b_v = None
        else:
            b_q, b_k, b_v = in_proj_bias.chunk(3)
        q = F.linear(query, q_proj_weight, b_q)
        k = F.linear(key, k_proj_weight, b_k)
        v = F.linear(value, v_proj_weight, b_v)
    else:
        q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias)

    if attn_mask is not None:
        if attn_mask.dtype == torch.uint8:
            warnings.warn("Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
            attn_mask = attn_mask.to(torch.bool)
        else:
            assert attn_mask.is_floating_point() or attn_mask.dtype == torch.bool, \
                f"Only float, byte, and bool types are supported for attn_mask, not {attn_mask.dtype}"

        if attn_mask.dim() == 2:
            correct_2d_size = (tgt_len, src_len)
            if attn_mask.shape != correct_2d_size:
                raise RuntimeError(
                    f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.")
            attn_mask = attn_mask.unsqueeze(0)
        elif attn_mask.dim() == 3:
            correct_3d_size = (bsz * num_heads, tgt_len, src_len)
            if attn_mask.shape != correct_3d_size:
                raise RuntimeError(
                    f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}.")
        else:
            raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")

    if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
        warnings.warn(
            "Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
        key_padding_mask = key_padding_mask.to(torch.bool)

    # Split heads: (len, bsz, embed_dim) -> (bsz * num_heads, len, head_dim).
    q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
    k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
    v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)

    # Fold the key padding mask into attn_mask, repeated for every head.
    if key_padding_mask is not None:
        assert key_padding_mask.shape == (bsz, src_len), \
            f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}"
        key_padding_mask = key_padding_mask.view(bsz, 1, 1, src_len). \
            expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, src_len)
        if attn_mask is None:
            attn_mask = key_padding_mask
        elif attn_mask.dtype == torch.bool:
            attn_mask = attn_mask.logical_or(key_padding_mask)
        else:
            attn_mask = attn_mask.masked_fill(key_padding_mask, float("-inf"))

    # Bool masks become additive masks: blocked positions get -inf before the softmax.
    if attn_mask is not None and attn_mask.dtype == torch.bool:
        new_attn_mask = torch.zeros_like(attn_mask, dtype=torch.float)
        new_attn_mask.masked_fill_(attn_mask, float("-inf"))
        attn_mask = new_attn_mask

    if not training:
        dropout_p = 0.0
    # Pass by keyword: scale_dot_attention's positional order is (q, k, v, dropout_p, attn_mask).
    # The positional call previously sent the mask as dropout_p and vice versa, so attention
    # dropout was never applied and a real mask tensor could not be used.
    attn_output, attn_output_weights = scale_dot_attention(q, k, v, dropout_p=dropout_p, attn_mask=attn_mask)
    attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
    attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias)
    if need_weights:
        # average attention weights over heads
        attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
        return attn_output, attn_output_weights.sum(dim=1) / num_heads
    else:
        return attn_output, None

In [23]:
def _in_projection_packed(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    w: Tensor,
    b: Optional[Tensor] = None,
) -> List[Tensor]:
    E = q.size(-1)
    if k is v:
        if q is k:
            return F.linear(q, w, b).chunk(3, dim=-1)
        else:
            w_q, w_kv = w.split([E, E * 2])
            if b is None:
                b_q = b_kv = None
            else:
                b_q, b_kv = b.split([E, E * 2])
            return (F.linear(q, w_q, b_q),) + F.linear(k, w_kv, b_kv).chunk(2, dim=-1)
    else:
        w_q, w_k, w_v = w.chunk(3)
        if b is None:
            b_q = b_k = b_v = None
        else:
            b_q, b_k, b_v = b.chunk(3)
        return F.linear(q, w_q, b_q), F.linear(k, w_k, b_k), F.linear(v, w_v, b_v)

In [24]:
class MultiheadAttention(nn.Module):
    """Multi-head attention module with the same parameter layout as ``torch.nn.MultiheadAttention``.

    When key/value widths equal ``embed_dim`` it registers one packed ``in_proj_weight`` of
    shape (3*embed_dim, embed_dim). Otherwise it registers separate ``q/k/v_proj_weight``.
    An ``out_proj`` Linear follows. The computation is delegated to
    ``multi_head_attention_forward``, which expects (seq, batch, embed) inputs;
    ``batch_first=True`` transposes to and from that layout.
    """

    def __init__(self, embed_dim, num_heads, dropout=0., bias=True,
                 kdim=None, vdim=None, batch_first=False) -> None:
        super(MultiheadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.kdim = embed_dim if kdim is None else kdim
        self.vdim = embed_dim if vdim is None else vdim
        # A single packed input projection is only possible when q, k and v share one width.
        self._qkv_same_embed_dim = self.kdim == self.embed_dim and self.vdim == self.embed_dim

        self.num_heads = num_heads
        self.dropout = dropout
        self.batch_first = batch_first
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"

        # Registration order matches torch.nn.MultiheadAttention (affects parameter order).
        if self._qkv_same_embed_dim:
            self.in_proj_weight = Parameter(torch.empty((3 * embed_dim, embed_dim)))
            for unused_name in ('q_proj_weight', 'k_proj_weight', 'v_proj_weight'):
                self.register_parameter(unused_name, None)
        else:
            self.q_proj_weight = Parameter(torch.empty((embed_dim, embed_dim)))
            self.k_proj_weight = Parameter(torch.empty((embed_dim, self.kdim)))
            self.v_proj_weight = Parameter(torch.empty((embed_dim, self.vdim)))
            self.register_parameter('in_proj_weight', None)

        if bias:
            self.in_proj_bias = Parameter(torch.empty(3 * embed_dim))
        else:
            self.register_parameter('in_proj_bias', None)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

        self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-uniform the input projection(s) and zero both biases (when present)."""
        if self._qkv_same_embed_dim:
            projections = (self.in_proj_weight,)
        else:
            projections = (self.q_proj_weight, self.k_proj_weight, self.v_proj_weight)
        for weight in projections:
            xavier_uniform_(weight)

        if self.in_proj_bias is not None:
            constant_(self.in_proj_bias, 0.)
            constant_(self.out_proj.bias, 0.)

    def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: Optional[Tensor] = None,
                need_weights: bool = True, attn_mask: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]:
        """Attend ``query`` over ``key``/``value``; returns (output, head-averaged weights or None)."""
        if self.batch_first:
            # The functional core works on (seq, batch, embed); swap the first two axes.
            query, key, value = (t.transpose(1, 0) for t in (query, key, value))

        separate_proj_kwargs = {}
        if not self._qkv_same_embed_dim:
            separate_proj_kwargs = dict(use_separate_proj_weight=True,
                                        q_proj_weight=self.q_proj_weight,
                                        k_proj_weight=self.k_proj_weight,
                                        v_proj_weight=self.v_proj_weight)

        attn_output, attn_output_weights = multi_head_attention_forward(
            query, key, value, self.num_heads,
            self.in_proj_weight, self.in_proj_bias,
            self.dropout, self.out_proj.weight, self.out_proj.bias,
            training=self.training,
            key_padding_mask=key_padding_mask, need_weights=need_weights,
            attn_mask=attn_mask, **separate_proj_kwargs)

        if self.batch_first:
            attn_output = attn_output.transpose(1, 0)
        return attn_output, attn_output_weights

In [25]:
class TransformerEncoderLayer(nn.Module):
    """Post-norm Transformer encoder layer.

    A self-attention block and a feed-forward block, each added back to its input and
    followed by LayerNorm.

    Args:
        d_model: feature width of each token.
        nhead: number of attention heads.
        dim_feedforward: hidden width of the feed-forward block.
        dropout: dropout probability for attention and the residual branches.
        activation: callable applied between linear1 and linear2.
        layer_norm_eps: epsilon of both LayerNorms.
        batch_first: forwarded to MultiheadAttention. When False, inputs are read as
            (seq, batch, d_model).
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation=F.relu,
                 layer_norm_eps=1e-5, batch_first=False) -> None:
        super(TransformerEncoderLayer, self).__init__()
        self.self_attn = MultiheadAttention(d_model,nhead,
                                            dropout=dropout, batch_first=batch_first)

        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.activation = activation

    def forward(self, src: Tensor, src_mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        """Return ``src`` transformed by self-attention and feed-forward (same shape)."""
        src2 = self.self_attn(src, src, src, attn_mask=src_mask,
                              key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        # The second residual branch uses its own dropout2 (same p as self.dropout),
        # consistent with dropout1/dropout3 in TransformerDecoderLayer.
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src

In [26]:
class TransformerEncoder(nn.Module):
    """Apply one encoder layer ``num_layers`` times, then an optional final ``norm``.

    NOTE: the same ``encoder_layer`` instance, and therefore the same weights, is reused at
    every step. This is a weight-shared recurrence, not a stack of independent layers.
    """

    def __init__(self, encoder_layer, num_layers, norm=None):
        super(TransformerEncoder, self).__init__()
        self.layer = encoder_layer
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, src: Tensor, mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        hidden = src
        for _ in range(self.num_layers):
            hidden = self.layer(hidden, src_mask=mask, src_key_padding_mask=src_key_padding_mask)
        return hidden if self.norm is None else self.norm(hidden)

In [27]:
#Decoder Layer:
class TransformerDecoderLayer(nn.Module):
    """Post-norm Transformer decoder layer.

    The layer applies, in order: self-attention over ``tgt``, cross-attention from ``tgt``
    to ``memory``, and a feed-forward block. Each step is added back to its input and
    followed by LayerNorm.
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation=F.relu,
                 layer_norm_eps=1e-5, batch_first=False) -> None:
        super(TransformerDecoderLayer, self).__init__()
        # Construction order is unchanged so parameter order and seeded initialisation match.
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first)
        self.multihead_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first)

        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = activation

    def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        # 1) self-attention over the target sequence
        attended, _ = self.self_attn(tgt, tgt, tgt, attn_mask=tgt_mask,
                                     key_padding_mask=tgt_key_padding_mask)
        tgt = self.norm1(tgt + self.dropout1(attended))
        # 2) cross-attention from the target onto the memory
        attended, _ = self.multihead_attn(tgt, memory, memory, attn_mask=memory_mask,
                                          key_padding_mask=memory_key_padding_mask)
        tgt = self.norm2(tgt + self.dropout2(attended))
        # 3) position-wise feed-forward
        ff = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        return self.norm3(tgt + self.dropout3(ff))

In [28]:
#Decoder
class TransformerDecoder(nn.Module):
    """Apply one decoder layer ``num_layers`` times against a fixed ``memory``, then optional ``norm``.

    NOTE: the same ``decoder_layer`` instance, and therefore the same weights, is reused at
    every step.
    """

    def __init__(self, decoder_layer, num_layers, norm=None):
        super(TransformerDecoder, self).__init__()
        self.layer = decoder_layer
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        masks = dict(tgt_mask=tgt_mask,
                     memory_mask=memory_mask,
                     tgt_key_padding_mask=tgt_key_padding_mask,
                     memory_key_padding_mask=memory_key_padding_mask)
        hidden = tgt
        for _ in range(self.num_layers):
            hidden = self.layer(hidden, memory, **masks)
        return hidden if self.norm is None else self.norm(hidden)


In [29]:
class task2Encoder(nn.Module):
    """Transformer encoder that compresses each sample into a ``total_size // reduction`` code.

    ``src`` is viewed as ``(-1, (total_size // 2) // d_model, d_model)``, so every sample
    must hold ``total_size // 2`` values. The view is passed ``num_encoder_layers`` times
    through one weight-shared TransformerEncoderLayer, then flattened per sample and
    projected by ``fc_encoder``. Output shape: (batch, total_size // reduction).

    ``num_decoder_layers`` and ``custom_decoder`` are accepted only for signature parity
    with task2Decoder and are unused here.

    NOTE(review): with ``batch_first=False`` (the default), MultiheadAttention reads axis 0 of
    the view -- the sample axis -- as the sequence axis. Attention then mixes the samples of
    a batch. ``batch_first=True`` attends within each sample instead.
    """

    def __init__(self,  d_model: int = 64, nhead: int = 8, num_encoder_layers: int = 6,
                 num_decoder_layers: int = 6, dim_feedforward: int = 2048, dropout: float = 0.1,
                 activation = F.relu, custom_encoder: Optional[Any] = None, custom_decoder: Optional[Any] = None,
                 layer_norm_eps: float = 1e-5, batch_first: bool = False,  reduction=64) -> None:
        super(task2Encoder, self).__init__()
        self.total_size = 8192

        if custom_encoder is not None:
            self.encoder = custom_encoder
        else:
            encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout,
                                                    activation, layer_norm_eps, batch_first)
            # No final LayerNorm is attached here (task2Decoder does attach one). Adding it
            # would change the parameter set of existing checkpoints.
            self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers)

        self.d_model = d_model

        # forward() splits the total_size // 2 input values into rows of d_model features.
        assert (self.total_size // 2) % self.d_model == 0, \
            f'total_size // 2 ({self.total_size // 2}) must be divisible by d_model ({self.d_model})'
        self.feature_shape = (self.total_size // (2 * self.d_model), self.d_model)

        self.nhead = nhead

        self.batch_first = batch_first
        self.fc_encoder = nn.Linear(self.total_size // 2, self.total_size // reduction)
        self._reset_parameters()

    def forward(self, src: Tensor, tgt: Optional[Tensor]=None, src_mask: Optional[Tensor] = None,
                    tgt_mask: Optional[Tensor] = None,
                    memory_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None,
                    tgt_key_padding_mask: Optional[Tensor] = None,
                    memory_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        """Encode ``src``; only ``src_mask`` and ``src_key_padding_mask`` are used among the masks."""
        memory = self.encoder(src.view(-1, self.feature_shape[0], self.feature_shape[1]),
                              mask=src_mask, src_key_padding_mask=src_key_padding_mask)
        memory_encoder = self.fc_encoder(memory.view(memory.shape[0], -1))

        return memory_encoder

    def generate_square_subsequent_mask(self, sz: int) -> Tensor:
        """Return an (sz, sz) float mask: 0 on/below the diagonal, -inf above it."""
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)

        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def _reset_parameters(self):
        """Xavier-uniform every parameter with more than one dimension."""
        for p in self.parameters():
            if p.dim() > 1:
                xavier_uniform_(p)

In [30]:
class task2Decoder(nn.Module):
    """Transformer decoder that expands a ``total_size // reduction`` code into a 64x64 map.

    Pipeline: ``fc_decoder`` maps the code to ``total_size`` values. These are viewed as
    ``(-1, total_size // d_model, d_model)`` and used as both target and memory of
    ``num_decoder_layers`` weight-shared TransformerDecoderLayer steps plus a final
    LayerNorm. The result is flattened, passed through ``decoder_fc2`` (to total_size // 2)
    and a sigmoid, and reshaped to (batch, -1, 64, 64).

    ``num_encoder_layers`` and ``custom_encoder`` are accepted only for signature parity
    with task2Encoder and are unused here.
    """

    def __init__(self,  d_model: int = 64, nhead: int = 8, num_encoder_layers: int = 6,
                 num_decoder_layers: int = 6, dim_feedforward: int = 2048, dropout: float = 0.1,
                 activation = F.relu, custom_encoder: Optional[Any] = None, custom_decoder: Optional[Any] = None,
                 layer_norm_eps: float = 1e-5, batch_first: bool = False,  reduction=64) -> None:
        super(task2Decoder, self).__init__()
        self.total_size = 8192
        if custom_decoder is not None:
            self.decoder = custom_decoder
        else:
            decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout,
                                                    activation, layer_norm_eps, batch_first)
            decoder_norm = nn.LayerNorm(d_model, eps=layer_norm_eps)
            self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm)

        self.d_model = d_model

        # forward() splits the total_size decoded values into rows of d_model features.
        assert self.total_size % self.d_model == 0, \
            f'total_size ({self.total_size}) must be divisible by d_model ({self.d_model})'
        self.feature_shape = (self.total_size // self.d_model, self.d_model)

        self.nhead = nhead

        self.batch_first = batch_first
        self.fc_decoder = nn.Linear(self.total_size // reduction, self.total_size)
        self._reset_parameters()

        # NOTE(review): decoder_fc2 is created after _reset_parameters(), so it keeps
        # nn.Linear's default initialisation rather than Xavier -- confirm this is intended.
        self.decoder_fc2 = nn.Linear(self.total_size, self.total_size // 2)
        self.sig2 = nn.Sigmoid()

    def forward(self, memory_encoder, tgt: Optional[Tensor]=None, src_mask: Optional[Tensor] = None,
                    tgt_mask: Optional[Tensor] = None,
                    memory_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None,
                    tgt_key_padding_mask: Optional[Tensor] = None,
                    memory_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        """Decode ``memory_encoder`` (batch, total_size // reduction) into (batch, -1, 64, 64) in (0, 1)."""
        bs = memory_encoder.size(0)

        # Generate final output
        memory_decoder = self.fc_decoder(memory_encoder).view(-1, self.feature_shape[0], self.feature_shape[1])
        output = self.decoder(memory_decoder, memory_decoder, tgt_mask=tgt_mask, memory_mask=memory_mask,
                              tgt_key_padding_mask=tgt_key_padding_mask,
                              memory_key_padding_mask=memory_key_padding_mask)

        output = self.sig2(self.decoder_fc2(output.view(bs, -1)))

        return output.view(bs, -1, 64, 64)

    def generate_square_subsequent_mask(self, sz: int) -> Tensor:
        """Return an (sz, sz) float mask: 0 on/below the diagonal, -inf above it."""
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)

        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def _reset_parameters(self):
        """Xavier-uniform every parameter with more than one dimension that exists at call time."""
        for p in self.parameters():
            if p.dim() > 1:
                xavier_uniform_(p)

In [31]:
#complete task 2 model including encoder, decoder and channel
class task2model(nn.Module):
    """Complete task-2 model: task2Encoder (CSI -> code) followed by task2Decoder.

    The decoder output has shape (batch, 1, 64, 64) with values in (0, 1). Both halves use
    d_model=32, nhead=2, two weight-shared layers and no dropout.

    Args:
        reduction: compression ratio passed to both halves (the code has 8192 // reduction values).
        batch_first: forwarded to both halves. Both halves reshape every sample to
            (seq, d_model) with the *sample* axis first. With the default ``False``,
            MultiheadAttention treats that sample axis as the sequence axis, so attention
            mixes different samples of the same batch. ``True`` attends within each sample.
            The default stays ``False`` so existing training runs and results are reproduced.
    """

    def __init__(self, reduction=16, batch_first=False):
        super().__init__()

        self.en = task2Encoder(d_model=32, num_encoder_layers=2, num_decoder_layers=2, nhead=2,
                               reduction=reduction, dropout=0., batch_first=batch_first)

        self.de = task2Decoder(d_model=32, num_encoder_layers=2, num_decoder_layers=2, nhead=2,
                               reduction=reduction, dropout=0., batch_first=batch_first)

    def forward(self, Hin, device=None, is_training=None):
        """Encode then decode ``Hin``.

        ``device`` and ``is_training`` are unused. They are kept (now optional) so existing
        call sites keep working.
        """
        # Encoder
        Hencoded = self.en(Hin)

        # Decoder
        Hdecoded = self.de(Hencoded)

        return Hdecoded

In [32]:
# Paths for loading pretrained weights: the task-2 baseline (trained with every
# side-information modality switched off) and, in the next cell, task 1.
onoffdictb = dict.fromkeys(('GPS', 'CAMERAS', 'RADAR'), False)  # baseline modality flags
weight_pathb = (f'models/TransNettask2/cr{reduction}/'
                f'gps{onoffdictb["GPS"]}_cam{onoffdictb["CAMERAS"]}_rad{onoffdictb["RADAR"]}/')

In [33]:
# Directory of the pretrained task-1 decoder trained with the same modality flags.
task1_weight_path = ('models/task1/'
                     f'gps{onoffdict["GPS"]}_cam{onoffdict["CAMERAS"]}_rad{onoffdict["RADAR"]}/')

In [34]:
class task1decoder(nn.Module):
    """Predict a (1, num_H, num_H) map from side information only.

    Each enabled modality (GPS, radar, three cameras) is embedded to 16 features; a disabled
    modality contributes zeros. The 5 * 16 concatenated features go through
    Linear -> ReLU -> Linear -> ReLU and are reshaped to (batch, 1, num_H, num_H).
    Reads the notebook-level ``num_H``.
    """

    def __init__(self):
        super().__init__()

        self.gp = gpsdata()
        self.rd = radardata()
        self.lc = cameradata()
        self.cc = cameradata()
        self.rc = cameradata()

        # Hidden width is (num_H / 2)^2, or 32 when that would be 32 or smaller.
        half = int(num_H / 2)
        hidden = half * half if half * half > 32 else 32
        self.linear = nn.Linear(16 * 5, hidden)
        self.output_fc = nn.Linear(hidden, num_H * num_H)
        self.output_relu = nn.ReLU()

        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.xavier_uniform_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, gps, radar, left_cam, center_cam, right_cam, onoffdict):
        """Args:
            gps, radar, left_cam, center_cam, right_cam: per-modality batches; ``gps`` also
                drives the camera crops.
            onoffdict: {'GPS', 'RADAR', 'CAMERAS'} -> bool switches.

        Returns:
            Tensor of shape (batch, 1, num_H, num_H).
        """
        bs = gps.size(0)
        # Placeholders for disabled modalities are created on the device of this module's own
        # weights, not the notebook-global `device`, so they always match the layer they feed.
        param_device = self.linear.weight.device

        def placeholder():
            return torch.zeros(bs, 16, device=param_device)

        gps_out = self.gp(gps) if onoffdict['GPS'] else placeholder()
        radar_out = self.rd(radar) if onoffdict['RADAR'] else placeholder()

        if onoffdict['CAMERAS']:
            left_cam, center_cam, right_cam = process_imgs(gps, left_cam, center_cam, right_cam, crop_size = (150,150))
            lc_out = self.lc(left_cam)
            cc_out = self.cc(center_cam)
            rc_out = self.rc(right_cam)
        else:
            lc_out, cc_out, rc_out = placeholder(), placeholder(), placeholder()

        combined = torch.cat((gps_out, radar_out, lc_out, cc_out, rc_out), dim=1)

        output = self.output_relu(self.linear(combined))
        output = self.output_relu(self.output_fc(output))
        return output.view(output.size(0), 1, num_H, num_H)

In [35]:
class Decoderwithmsi(nn.Module):
    """Task-2 decoder fusing a baseline CSI reconstruction with multi-sensor side information.

    Pretrained parts (unpickled from notebook-level paths at construction time):
      * ``gp``/``rd``/``lc``/``cc``/``rc`` and ``linear``: sensor branches and fusion layer
        of the task-1 decoder stored at ``task1_weight_path + "task1Decoder.pth"``;
      * ``bde``: baseline task-2 decoder stored at ``weight_pathb + "task2Decoder.pth"``.
    New layers: ``output_fc1`` and ``output_fc2``. They merge the side-information feature
    with the flattened baseline reconstruction.

    Args:
        reduction: compression ratio; accepted for interface compatibility, unused here.
    """

    def __init__(self, reduction):
        super().__init__()
        self.task1decoder = torch.load(task1_weight_path+"task1Decoder.pth")
        self.gp = self.task1decoder.gp
        self.rd = self.task1decoder.rd
        self.lc = self.task1decoder.lc
        self.cc = self.task1decoder.cc
        self.rc = self.task1decoder.rc
        self.bde = torch.load(weight_pathb+"task2Decoder.pth")
        # When False, `bde` and the side-information branch run under torch.no_grad().
        self.allow_update = False

        self.linear = self.task1decoder.linear
        # Width of the side-information feature, read from the pretrained fusion layer
        # (the task-1 decoder builds it as max(32, (num_H//2)**2) outputs).
        side_dim = self.linear.out_features
        self.output_fc1 = nn.Linear(side_dim + num_H*num_H, 2*num_H*num_H)
        self.output_fc2 = nn.Linear(2*num_H*num_H, num_H*num_H)
        self.output_relu = nn.ReLU()

        # Initialise ONLY the newly created layers. Looping over self.modules() would also
        # re-initialise every Linear/Conv2d/BatchNorm2d inside the pretrained task1decoder
        # and bde, silently discarding the weights that were just loaded.
        for layer in (self.output_fc1, self.output_fc2):
            nn.init.xavier_uniform_(layer.weight)

    def _pretrained_stage(self, Hencoded, gps, radar, left_cam, center_cam, right_cam, onoffdict):
        """Run the reused pretrained modules.

        Returns:
            (Hdecoded, side): the output of the baseline decoder ``bde`` and the fused
            side-information feature produced by ``linear`` (before the ReLU).
        """
        bs = Hencoded.size(0)
        Hdecoded = self.bde(Hencoded)
        # Disabled modalities contribute a (bs, 16) zero block so the fused input width
        # stays 16*5. Allocate on the model's device, not the notebook-global `device`.
        param_device = self.linear.weight.device

        def zeros():
            return torch.zeros(bs, 16, device=param_device)

        gps_out = self.gp(gps) if onoffdict['GPS'] else zeros()
        radar_out = self.rd(radar) if onoffdict['RADAR'] else zeros()
        if onoffdict['CAMERAS']:
            left_cam, center_cam, right_cam = process_imgs(gps, left_cam, center_cam, right_cam, crop_size = (150,150))
            lc_out = self.lc(left_cam)
            cc_out = self.cc(center_cam)
            rc_out = self.rc(right_cam)
        else:
            lc_out, cc_out, rc_out = zeros(), zeros(), zeros()

        combined = torch.cat((gps_out, radar_out, lc_out, cc_out, rc_out), dim=1)
        return Hdecoded, self.linear(combined)

    def forward(self, Hencoded,gps,radar,left_cam,center_cam,right_cam,onoffdict):
        """Reconstruct the channel as a (bs, 1, n, n) tensor, n*n == output_fc2.out_features.

        The pretrained stage builds an autograd graph only when ``allow_update`` is True.
        ``output_fc1``/``output_fc2`` follow the caller's grad mode.
        """
        bs = Hencoded.size(0)
        stage_args = (Hencoded, gps, radar, left_cam, center_cam, right_cam, onoffdict)
        if self.allow_update:
            Hdecoded, side = self._pretrained_stage(*stage_args)
        else:
            with torch.no_grad():
                Hdecoded, side = self._pretrained_stage(*stage_args)

        fused = torch.cat((self.output_relu(side), Hdecoded.reshape(bs, -1)), dim=1)
        output = self.output_fc2(self.output_relu(self.output_fc1(fused)))
        side_len = math.isqrt(self.output_fc2.out_features)
        return output.view(bs, 1, side_len, side_len)

In [36]:
#complete task 2 model including encoder, decoder and channel
class task2model(nn.Module):
    """Complete task-2 model: pretrained baseline encoder followed by ``Decoderwithmsi``.

    The encoder is unpickled from ``weight_pathb + "task2Encoder.pth"`` (a notebook-level
    path). ``allow_update`` gates only the encoder; the decoder carries its own flag.
    """

    def __init__(self, reduction=16):
        super().__init__()
        self.en = torch.load(weight_pathb + "task2Encoder.pth")
        # Encoder starts frozen: its forward pass runs without autograd.
        self.allow_update = False
        self.de = Decoderwithmsi(reduction)

    def forward(self, Hin, gps, radar, left_cam, center_cam, right_cam, onoffdict):
        """Encode ``Hin`` and decode it together with the side-information inputs."""
        # Frozen encoder -> grad disabled; otherwise keep whatever mode the caller set.
        grad_mode = bool(self.allow_update) and torch.is_grad_enabled()
        with torch.set_grad_enabled(grad_mode):
            Hencoded = self.en(Hin)
        return self.de(Hencoded, gps, radar, left_cam, center_cam, right_cam, onoffdict)

In [37]:
# Reconstruction loss, also used as the evaluation metric below.
criterion = nn.MSELoss().to(device)

# Inference

In [38]:
def run_test(model, test_loader, device, criterion, modalities=None):
    """Evaluate `model` on `test_loader` and return the per-sample average of `criterion`.

    Each batch's loss is weighted by its number of samples. An unweighted mean over
    batches over-weights a short final batch. With ``shuffle=True`` that batch's contents
    change on every pass, which also adds spurious run-to-run variation.

    Args:
        model: called as ``model(H, gps, radar, left_cam, center_cam, right_cam,
            onoffdict=...)``, with ``H = CSI_abs_reshape(X[0])`` and ``X[1:6]`` the
            remaining inputs of the batch.
        test_loader: iterable of ``(X, y)`` batches (batch dimension first).
        device: device the inputs and targets are moved to.
        criterion: loss returning a batch mean (e.g. ``nn.MSELoss()``).
        modalities: optional modality on/off dict passed to the model as ``onoffdict``;
            defaults to the notebook-level ``onoffdict``.

    Returns:
        float: the sample-weighted average loss.

    Raises:
        ValueError: if `test_loader` yields no samples.
    """
    active_modalities = onoffdict if modalities is None else modalities
    model.eval()
    weighted_sum = 0.0
    n_samples = 0
    with torch.no_grad():
        for X_test, y_test in test_loader:
            n_batch = len(y_test)
            y_test = y_test.to(device)
            Xin = CSI_abs_reshape(X_test[0])
            side_inputs = [t.to(device) for t in X_test[1:6]]
            y_pred = model(Xin.to(device), *side_inputs, onoffdict=active_modalities)
            loss = criterion(y_pred, CSI_abs_reshape(y_test))
            # Accumulate on-device; a single host sync happens in float() below.
            weighted_sum = weighted_sum + loss * n_batch
            n_samples += n_batch
    if n_samples == 0:
        raise ValueError("run_test: test_loader yielded no samples")
    return float(weighted_sum) / n_samples

In [39]:
def calculate_confidence_interval(data, confidence=0.95):
    """Return ``(mean, half_width)`` of a two-sided Student-t confidence interval.

    ``half_width`` is ``sem(data) * t.ppf((1 + confidence) / 2, n - 1)``, so the interval
    is ``mean +/- half_width``. With fewer than two samples scipy yields NaN.
    """
    sample_size = len(data)
    dof = sample_size - 1
    quantile = (1 + confidence) / 2.
    center = np.mean(data)
    t_crit = stats.t.ppf(quantile, dof)
    return center, stats.sem(data) * t_crit

In [40]:
# Validation split used by every evaluation below.
test_dataset = DeepVerseChallengeLoaderTaskTwo(csv_path=r'./dataset_validation.csv')
test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=5,
)

In [41]:

# Baseline for the "% improvement" metric: the MSE of always predicting the mean |H|
# over all validation targets (the population variance of |H|).
# Batches are collected in a list and concatenated once. Calling torch.cat inside the
# loop re-copies the growing tensor every iteration, which is quadratic.
h_batches = []
for b, (x, h) in enumerate(test_loader):
    h = CSI_abs_reshape(h)
    h_batches.append(h)
h_list = torch.cat(h_batches)
abs_h = torch.abs(h_list)
target_loss = torch.mean((abs_h - torch.mean(abs_h)) ** 2)

In [42]:
num_runs = 10  # repeated evaluation passes over the reshuffled validation set

In [43]:
# Evaluate the saved task-2 checkpoint `num_runs` times on freshly shuffled validation
# batches, then report the mean and 95% t-interval half-width of the MSE and of the
# improvement over `target_loss`.
# The checkpoint and the DataLoader are loop-invariant. run_test switches the model to
# eval mode under torch.no_grad(), so the weights are identical on every pass, and each
# new pass over a shuffle=True DataLoader draws a new permutation anyway.
# NOTE: torch.load unpickles a full module -- only load trusted checkpoints.
model = torch.load(weight_path + "task2.pth").to(device)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True, num_workers=5)
baseline_mse = target_loss.item()
avg_mse_list = []
improvement_list = []
for _ in range(num_runs):
    avg_mse = run_test(model, test_loader, device, criterion)
    avg_mse_list.append(avg_mse)
    # Relative MSE reduction (%) versus the baseline MSE stored in target_loss.
    improvement = (baseline_mse - avg_mse) / baseline_mse * 100
    improvement_list.append(improvement)
mean_mse, margin_of_error = calculate_confidence_interval(avg_mse_list)
improvement_mean, improvement_margin_of_error = calculate_confidence_interval(improvement_list)
print(f'Percentage Improvement Mean Achieved: {improvement_mean:.4f}%')
print(f'Percentage Improvement Confidence Interval Achieved: {improvement_margin_of_error:.4f}%')
print(f"Mean MSE: {mean_mse:.4f}")
print(f"95% Confidence Interval: ({mean_mse - margin_of_error:.4f}, {mean_mse + margin_of_error:.4f})")
print(f"Margin of Error: {margin_of_error:.4f}")

Percentage Improvement Mean Achieved: 56.6364%
Percentage Improvement Confidence Interval Achieved: 0.4498%
Mean MSE: 0.4838
95% Confidence Interval: (0.4788, 0.4889)
Margin of Error: 0.0050


# Change Dictionary

In [44]:
# Modality ablation: GPS and cameras on, radar off.
onoffdict = dict(GPS=True, CAMERAS=True, RADAR=False)
weight_path = f'models/TransNettask2/cr{reduction}/gps{onoffdict["GPS"]}_cam{onoffdict["CAMERAS"]}_rad{onoffdict["RADAR"]}/'

In [45]:
# Ensure the checkpoint directory exists; exist_ok avoids the check-then-create race.
os.makedirs(weight_path, exist_ok=True)

In [46]:
# Pretrained task-1 decoder for the same modality combination.
task1_weight_path = f'models/task1/gps{onoffdict["GPS"]}_cam{onoffdict["CAMERAS"]}_rad{onoffdict["RADAR"]}/'

In [47]:
class Decoderwithmsi(nn.Module):
    """Task-2 decoder fusing a baseline CSI reconstruction with multi-sensor side information.

    Pretrained parts (unpickled from notebook-level paths at construction time):
      * ``gp``/``rd``/``lc``/``cc``/``rc`` and ``linear``: sensor branches and fusion layer
        of the task-1 decoder stored at ``task1_weight_path + "task1Decoder.pth"``;
      * ``bde``: baseline task-2 decoder stored at ``weight_pathb + "task2Decoder.pth"``.
    New layers: ``output_fc1`` and ``output_fc2``. They merge the side-information feature
    with the flattened baseline reconstruction.

    Args:
        reduction: compression ratio; accepted for interface compatibility, unused here.
    """

    def __init__(self, reduction):
        super().__init__()
        self.task1decoder = torch.load(task1_weight_path+"task1Decoder.pth")
        self.gp = self.task1decoder.gp
        self.rd = self.task1decoder.rd
        self.lc = self.task1decoder.lc
        self.cc = self.task1decoder.cc
        self.rc = self.task1decoder.rc
        self.bde = torch.load(weight_pathb+"task2Decoder.pth")
        # When False, `bde` and the side-information branch run under torch.no_grad().
        self.allow_update = False

        self.linear = self.task1decoder.linear
        # Width of the side-information feature, read from the pretrained fusion layer
        # (the task-1 decoder builds it as max(32, (num_H//2)**2) outputs).
        side_dim = self.linear.out_features
        self.output_fc1 = nn.Linear(side_dim + num_H*num_H, 2*num_H*num_H)
        self.output_fc2 = nn.Linear(2*num_H*num_H, num_H*num_H)
        self.output_relu = nn.ReLU()

        # Initialise ONLY the newly created layers. Looping over self.modules() would also
        # re-initialise every Linear/Conv2d/BatchNorm2d inside the pretrained task1decoder
        # and bde, silently discarding the weights that were just loaded.
        for layer in (self.output_fc1, self.output_fc2):
            nn.init.xavier_uniform_(layer.weight)

    def _pretrained_stage(self, Hencoded, gps, radar, left_cam, center_cam, right_cam, onoffdict):
        """Run the reused pretrained modules.

        Returns:
            (Hdecoded, side): the output of the baseline decoder ``bde`` and the fused
            side-information feature produced by ``linear`` (before the ReLU).
        """
        bs = Hencoded.size(0)
        Hdecoded = self.bde(Hencoded)
        # Disabled modalities contribute a (bs, 16) zero block so the fused input width
        # stays 16*5. Allocate on the model's device, not the notebook-global `device`.
        param_device = self.linear.weight.device

        def zeros():
            return torch.zeros(bs, 16, device=param_device)

        gps_out = self.gp(gps) if onoffdict['GPS'] else zeros()
        radar_out = self.rd(radar) if onoffdict['RADAR'] else zeros()
        if onoffdict['CAMERAS']:
            left_cam, center_cam, right_cam = process_imgs(gps, left_cam, center_cam, right_cam, crop_size = (150,150))
            lc_out = self.lc(left_cam)
            cc_out = self.cc(center_cam)
            rc_out = self.rc(right_cam)
        else:
            lc_out, cc_out, rc_out = zeros(), zeros(), zeros()

        combined = torch.cat((gps_out, radar_out, lc_out, cc_out, rc_out), dim=1)
        return Hdecoded, self.linear(combined)

    def forward(self, Hencoded,gps,radar,left_cam,center_cam,right_cam,onoffdict):
        """Reconstruct the channel as a (bs, 1, n, n) tensor, n*n == output_fc2.out_features.

        The pretrained stage builds an autograd graph only when ``allow_update`` is True.
        ``output_fc1``/``output_fc2`` follow the caller's grad mode.
        """
        bs = Hencoded.size(0)
        stage_args = (Hencoded, gps, radar, left_cam, center_cam, right_cam, onoffdict)
        if self.allow_update:
            Hdecoded, side = self._pretrained_stage(*stage_args)
        else:
            with torch.no_grad():
                Hdecoded, side = self._pretrained_stage(*stage_args)

        fused = torch.cat((self.output_relu(side), Hdecoded.reshape(bs, -1)), dim=1)
        output = self.output_fc2(self.output_relu(self.output_fc1(fused)))
        side_len = math.isqrt(self.output_fc2.out_features)
        return output.view(bs, 1, side_len, side_len)

In [48]:
#complete task 2 model including encoder, decoder and channel
class task2model(nn.Module):
    """Complete task-2 model: pretrained baseline encoder followed by ``Decoderwithmsi``.

    The encoder is unpickled from ``weight_pathb + "task2Encoder.pth"`` (a notebook-level
    path). ``allow_update`` gates only the encoder; the decoder carries its own flag.
    """

    def __init__(self, reduction=16):
        super().__init__()
        self.en = torch.load(weight_pathb + "task2Encoder.pth")
        # Encoder starts frozen: its forward pass runs without autograd.
        self.allow_update = False
        self.de = Decoderwithmsi(reduction)

    def forward(self, Hin, gps, radar, left_cam, center_cam, right_cam, onoffdict):
        """Encode ``Hin`` and decode it together with the side-information inputs."""
        # Frozen encoder -> grad disabled; otherwise keep whatever mode the caller set.
        grad_mode = bool(self.allow_update) and torch.is_grad_enabled()
        with torch.set_grad_enabled(grad_mode):
            Hencoded = self.en(Hin)
        return self.de(Hencoded, gps, radar, left_cam, center_cam, right_cam, onoffdict)

# Inference

In [49]:
# Evaluate the saved task-2 checkpoint `num_runs` times on freshly shuffled validation
# batches, then report the mean and 95% t-interval half-width of the MSE and of the
# improvement over `target_loss`.
# The checkpoint and the DataLoader are loop-invariant. run_test switches the model to
# eval mode under torch.no_grad(), so the weights are identical on every pass, and each
# new pass over a shuffle=True DataLoader draws a new permutation anyway.
# NOTE: torch.load unpickles a full module -- only load trusted checkpoints.
model = torch.load(weight_path + "task2.pth").to(device)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True, num_workers=5)
baseline_mse = target_loss.item()
avg_mse_list = []
improvement_list = []
for _ in range(num_runs):
    avg_mse = run_test(model, test_loader, device, criterion)
    avg_mse_list.append(avg_mse)
    # Relative MSE reduction (%) versus the baseline MSE stored in target_loss.
    improvement = (baseline_mse - avg_mse) / baseline_mse * 100
    improvement_list.append(improvement)
mean_mse, margin_of_error = calculate_confidence_interval(avg_mse_list)
improvement_mean, improvement_margin_of_error = calculate_confidence_interval(improvement_list)
print(f'Percentage Improvement Mean Achieved: {improvement_mean:.4f}%')
print(f'Percentage Improvement Confidence Interval Achieved: {improvement_margin_of_error:.4f}%')
print(f"Mean MSE: {mean_mse:.4f}")
print(f"95% Confidence Interval: ({mean_mse - margin_of_error:.4f}, {mean_mse + margin_of_error:.4f})")
print(f"Margin of Error: {margin_of_error:.4f}")

Percentage Improvement Mean Achieved: 56.0821%
Percentage Improvement Confidence Interval Achieved: 0.6571%
Mean MSE: 0.4900
95% Confidence Interval: (0.4827, 0.4974)
Margin of Error: 0.0073


# Change Dictionary

In [50]:
# Modality ablation: GPS and radar on, cameras off.
onoffdict = dict(GPS=True, CAMERAS=False, RADAR=True)
weight_path = f'models/TransNettask2/cr{reduction}/gps{onoffdict["GPS"]}_cam{onoffdict["CAMERAS"]}_rad{onoffdict["RADAR"]}/'

In [51]:
# Ensure the checkpoint directory exists; exist_ok avoids the check-then-create race.
os.makedirs(weight_path, exist_ok=True)

In [52]:
# Pretrained task-1 decoder for the same modality combination.
task1_weight_path = f'models/task1/gps{onoffdict["GPS"]}_cam{onoffdict["CAMERAS"]}_rad{onoffdict["RADAR"]}/'

In [53]:
class Decoderwithmsi(nn.Module):
    """Task-2 decoder fusing a baseline CSI reconstruction with multi-sensor side information.

    Pretrained parts (unpickled from notebook-level paths at construction time):
      * ``gp``/``rd``/``lc``/``cc``/``rc`` and ``linear``: sensor branches and fusion layer
        of the task-1 decoder stored at ``task1_weight_path + "task1Decoder.pth"``;
      * ``bde``: baseline task-2 decoder stored at ``weight_pathb + "task2Decoder.pth"``.
    New layers: ``output_fc1`` and ``output_fc2``. They merge the side-information feature
    with the flattened baseline reconstruction.

    Args:
        reduction: compression ratio; accepted for interface compatibility, unused here.
    """

    def __init__(self, reduction):
        super().__init__()
        self.task1decoder = torch.load(task1_weight_path+"task1Decoder.pth")
        self.gp = self.task1decoder.gp
        self.rd = self.task1decoder.rd
        self.lc = self.task1decoder.lc
        self.cc = self.task1decoder.cc
        self.rc = self.task1decoder.rc
        self.bde = torch.load(weight_pathb+"task2Decoder.pth")
        # When False, `bde` and the side-information branch run under torch.no_grad().
        self.allow_update = False

        self.linear = self.task1decoder.linear
        # Width of the side-information feature, read from the pretrained fusion layer
        # (the task-1 decoder builds it as max(32, (num_H//2)**2) outputs).
        side_dim = self.linear.out_features
        self.output_fc1 = nn.Linear(side_dim + num_H*num_H, 2*num_H*num_H)
        self.output_fc2 = nn.Linear(2*num_H*num_H, num_H*num_H)
        self.output_relu = nn.ReLU()

        # Initialise ONLY the newly created layers. Looping over self.modules() would also
        # re-initialise every Linear/Conv2d/BatchNorm2d inside the pretrained task1decoder
        # and bde, silently discarding the weights that were just loaded.
        for layer in (self.output_fc1, self.output_fc2):
            nn.init.xavier_uniform_(layer.weight)

    def _pretrained_stage(self, Hencoded, gps, radar, left_cam, center_cam, right_cam, onoffdict):
        """Run the reused pretrained modules.

        Returns:
            (Hdecoded, side): the output of the baseline decoder ``bde`` and the fused
            side-information feature produced by ``linear`` (before the ReLU).
        """
        bs = Hencoded.size(0)
        Hdecoded = self.bde(Hencoded)
        # Disabled modalities contribute a (bs, 16) zero block so the fused input width
        # stays 16*5. Allocate on the model's device, not the notebook-global `device`.
        param_device = self.linear.weight.device

        def zeros():
            return torch.zeros(bs, 16, device=param_device)

        gps_out = self.gp(gps) if onoffdict['GPS'] else zeros()
        radar_out = self.rd(radar) if onoffdict['RADAR'] else zeros()
        if onoffdict['CAMERAS']:
            left_cam, center_cam, right_cam = process_imgs(gps, left_cam, center_cam, right_cam, crop_size = (150,150))
            lc_out = self.lc(left_cam)
            cc_out = self.cc(center_cam)
            rc_out = self.rc(right_cam)
        else:
            lc_out, cc_out, rc_out = zeros(), zeros(), zeros()

        combined = torch.cat((gps_out, radar_out, lc_out, cc_out, rc_out), dim=1)
        return Hdecoded, self.linear(combined)

    def forward(self, Hencoded,gps,radar,left_cam,center_cam,right_cam,onoffdict):
        """Reconstruct the channel as a (bs, 1, n, n) tensor, n*n == output_fc2.out_features.

        The pretrained stage builds an autograd graph only when ``allow_update`` is True.
        ``output_fc1``/``output_fc2`` follow the caller's grad mode.
        """
        bs = Hencoded.size(0)
        stage_args = (Hencoded, gps, radar, left_cam, center_cam, right_cam, onoffdict)
        if self.allow_update:
            Hdecoded, side = self._pretrained_stage(*stage_args)
        else:
            with torch.no_grad():
                Hdecoded, side = self._pretrained_stage(*stage_args)

        fused = torch.cat((self.output_relu(side), Hdecoded.reshape(bs, -1)), dim=1)
        output = self.output_fc2(self.output_relu(self.output_fc1(fused)))
        side_len = math.isqrt(self.output_fc2.out_features)
        return output.view(bs, 1, side_len, side_len)

In [54]:
#complete task 2 model including encoder, decoder and channel
class task2model(nn.Module):
    """Complete task-2 model: pretrained baseline encoder followed by ``Decoderwithmsi``.

    The encoder is unpickled from ``weight_pathb + "task2Encoder.pth"`` (a notebook-level
    path). ``allow_update`` gates only the encoder; the decoder carries its own flag.
    """

    def __init__(self, reduction=16):
        super().__init__()
        self.en = torch.load(weight_pathb + "task2Encoder.pth")
        # Encoder starts frozen: its forward pass runs without autograd.
        self.allow_update = False
        self.de = Decoderwithmsi(reduction)

    def forward(self, Hin, gps, radar, left_cam, center_cam, right_cam, onoffdict):
        """Encode ``Hin`` and decode it together with the side-information inputs."""
        # Frozen encoder -> grad disabled; otherwise keep whatever mode the caller set.
        grad_mode = bool(self.allow_update) and torch.is_grad_enabled()
        with torch.set_grad_enabled(grad_mode):
            Hencoded = self.en(Hin)
        return self.de(Hencoded, gps, radar, left_cam, center_cam, right_cam, onoffdict)

In [55]:
# Build the task-2 model from its pretrained parts (note: the evaluation cell below
# rebinds `model` to the saved task2.pth checkpoint).
model = task2model(reduction=reduction)


# Inference

In [56]:
# Evaluate the saved task-2 checkpoint `num_runs` times on freshly shuffled validation
# batches, then report the mean and 95% t-interval half-width of the MSE and of the
# improvement over `target_loss`.
# The checkpoint and the DataLoader are loop-invariant. run_test switches the model to
# eval mode under torch.no_grad(), so the weights are identical on every pass, and each
# new pass over a shuffle=True DataLoader draws a new permutation anyway.
# NOTE: torch.load unpickles a full module -- only load trusted checkpoints.
model = torch.load(weight_path + "task2.pth").to(device)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True, num_workers=5)
baseline_mse = target_loss.item()
avg_mse_list = []
improvement_list = []
for _ in range(num_runs):
    avg_mse = run_test(model, test_loader, device, criterion)
    avg_mse_list.append(avg_mse)
    # Relative MSE reduction (%) versus the baseline MSE stored in target_loss.
    improvement = (baseline_mse - avg_mse) / baseline_mse * 100
    improvement_list.append(improvement)
mean_mse, margin_of_error = calculate_confidence_interval(avg_mse_list)
improvement_mean, improvement_margin_of_error = calculate_confidence_interval(improvement_list)
print(f'Percentage Improvement Mean Achieved: {improvement_mean:.4f}%')
print(f'Percentage Improvement Confidence Interval Achieved: {improvement_margin_of_error:.4f}%')
print(f"Mean MSE: {mean_mse:.4f}")
print(f"95% Confidence Interval: ({mean_mse - margin_of_error:.4f}, {mean_mse + margin_of_error:.4f})")
print(f"Margin of Error: {margin_of_error:.4f}")

Percentage Improvement Mean Achieved: 56.0597%
Percentage Improvement Confidence Interval Achieved: 0.4678%
Mean MSE: 0.4903
95% Confidence Interval: (0.4851, 0.4955)
Margin of Error: 0.0052


# Change Dictionary

In [57]:
# Modality ablation: GPS only, cameras and radar off.
onoffdict = dict(GPS=True, CAMERAS=False, RADAR=False)
weight_path = f'models/TransNettask2/cr{reduction}/gps{onoffdict["GPS"]}_cam{onoffdict["CAMERAS"]}_rad{onoffdict["RADAR"]}/'

In [58]:
# Ensure the checkpoint directory exists; exist_ok avoids the check-then-create race.
os.makedirs(weight_path, exist_ok=True)

In [59]:
# Pretrained task-1 decoder for the same modality combination.
task1_weight_path = f'models/task1/gps{onoffdict["GPS"]}_cam{onoffdict["CAMERAS"]}_rad{onoffdict["RADAR"]}/'

In [60]:
class Decoderwithmsi(nn.Module):
    """Task-2 decoder fusing a baseline CSI reconstruction with multi-sensor side information.

    Pretrained parts (unpickled from notebook-level paths at construction time):
      * ``gp``/``rd``/``lc``/``cc``/``rc`` and ``linear``: sensor branches and fusion layer
        of the task-1 decoder stored at ``task1_weight_path + "task1Decoder.pth"``;
      * ``bde``: baseline task-2 decoder stored at ``weight_pathb + "task2Decoder.pth"``.
    New layers: ``output_fc1`` and ``output_fc2``. They merge the side-information feature
    with the flattened baseline reconstruction.

    Args:
        reduction: compression ratio; accepted for interface compatibility, unused here.
    """

    def __init__(self, reduction):
        super().__init__()
        self.task1decoder = torch.load(task1_weight_path+"task1Decoder.pth")
        self.gp = self.task1decoder.gp
        self.rd = self.task1decoder.rd
        self.lc = self.task1decoder.lc
        self.cc = self.task1decoder.cc
        self.rc = self.task1decoder.rc
        self.bde = torch.load(weight_pathb+"task2Decoder.pth")
        # When False, `bde` and the side-information branch run under torch.no_grad().
        self.allow_update = False

        self.linear = self.task1decoder.linear
        # Width of the side-information feature, read from the pretrained fusion layer
        # (the task-1 decoder builds it as max(32, (num_H//2)**2) outputs).
        side_dim = self.linear.out_features
        self.output_fc1 = nn.Linear(side_dim + num_H*num_H, 2*num_H*num_H)
        self.output_fc2 = nn.Linear(2*num_H*num_H, num_H*num_H)
        self.output_relu = nn.ReLU()

        # Initialise ONLY the newly created layers. Looping over self.modules() would also
        # re-initialise every Linear/Conv2d/BatchNorm2d inside the pretrained task1decoder
        # and bde, silently discarding the weights that were just loaded.
        for layer in (self.output_fc1, self.output_fc2):
            nn.init.xavier_uniform_(layer.weight)

    def _pretrained_stage(self, Hencoded, gps, radar, left_cam, center_cam, right_cam, onoffdict):
        """Run the reused pretrained modules.

        Returns:
            (Hdecoded, side): the output of the baseline decoder ``bde`` and the fused
            side-information feature produced by ``linear`` (before the ReLU).
        """
        bs = Hencoded.size(0)
        Hdecoded = self.bde(Hencoded)
        # Disabled modalities contribute a (bs, 16) zero block so the fused input width
        # stays 16*5. Allocate on the model's device, not the notebook-global `device`.
        param_device = self.linear.weight.device

        def zeros():
            return torch.zeros(bs, 16, device=param_device)

        gps_out = self.gp(gps) if onoffdict['GPS'] else zeros()
        radar_out = self.rd(radar) if onoffdict['RADAR'] else zeros()
        if onoffdict['CAMERAS']:
            left_cam, center_cam, right_cam = process_imgs(gps, left_cam, center_cam, right_cam, crop_size = (150,150))
            lc_out = self.lc(left_cam)
            cc_out = self.cc(center_cam)
            rc_out = self.rc(right_cam)
        else:
            lc_out, cc_out, rc_out = zeros(), zeros(), zeros()

        combined = torch.cat((gps_out, radar_out, lc_out, cc_out, rc_out), dim=1)
        return Hdecoded, self.linear(combined)

    def forward(self, Hencoded,gps,radar,left_cam,center_cam,right_cam,onoffdict):
        """Reconstruct the channel as a (bs, 1, n, n) tensor, n*n == output_fc2.out_features.

        The pretrained stage builds an autograd graph only when ``allow_update`` is True.
        ``output_fc1``/``output_fc2`` follow the caller's grad mode.
        """
        bs = Hencoded.size(0)
        stage_args = (Hencoded, gps, radar, left_cam, center_cam, right_cam, onoffdict)
        if self.allow_update:
            Hdecoded, side = self._pretrained_stage(*stage_args)
        else:
            with torch.no_grad():
                Hdecoded, side = self._pretrained_stage(*stage_args)

        fused = torch.cat((self.output_relu(side), Hdecoded.reshape(bs, -1)), dim=1)
        output = self.output_fc2(self.output_relu(self.output_fc1(fused)))
        side_len = math.isqrt(self.output_fc2.out_features)
        return output.view(bs, 1, side_len, side_len)

In [61]:
#complete task 2 model including encoder, decoder and channel
class task2model(nn.Module):
    """Complete task-2 model: pretrained baseline encoder followed by ``Decoderwithmsi``.

    The encoder is unpickled from ``weight_pathb + "task2Encoder.pth"`` (a notebook-level
    path). ``allow_update`` gates only the encoder; the decoder carries its own flag.
    """

    def __init__(self, reduction=16):
        super().__init__()
        self.en = torch.load(weight_pathb + "task2Encoder.pth")
        # Encoder starts frozen: its forward pass runs without autograd.
        self.allow_update = False
        self.de = Decoderwithmsi(reduction)

    def forward(self, Hin, gps, radar, left_cam, center_cam, right_cam, onoffdict):
        """Encode ``Hin`` and decode it together with the side-information inputs."""
        # Frozen encoder -> grad disabled; otherwise keep whatever mode the caller set.
        grad_mode = bool(self.allow_update) and torch.is_grad_enabled()
        with torch.set_grad_enabled(grad_mode):
            Hencoded = self.en(Hin)
        return self.de(Hencoded, gps, radar, left_cam, center_cam, right_cam, onoffdict)

# Inference

In [62]:
# Evaluate the saved task-2 checkpoint `num_runs` times on freshly shuffled validation
# batches, then report the mean and 95% t-interval half-width of the MSE and of the
# improvement over `target_loss`.
# The checkpoint and the DataLoader are loop-invariant. run_test switches the model to
# eval mode under torch.no_grad(), so the weights are identical on every pass, and each
# new pass over a shuffle=True DataLoader draws a new permutation anyway.
# NOTE: torch.load unpickles a full module -- only load trusted checkpoints.
model = torch.load(weight_path + "task2.pth").to(device)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True, num_workers=5)
baseline_mse = target_loss.item()
avg_mse_list = []
improvement_list = []
for _ in range(num_runs):
    avg_mse = run_test(model, test_loader, device, criterion)
    avg_mse_list.append(avg_mse)
    # Relative MSE reduction (%) versus the baseline MSE stored in target_loss.
    improvement = (baseline_mse - avg_mse) / baseline_mse * 100
    improvement_list.append(improvement)
mean_mse, margin_of_error = calculate_confidence_interval(avg_mse_list)
improvement_mean, improvement_margin_of_error = calculate_confidence_interval(improvement_list)
print(f'Percentage Improvement Mean Achieved: {improvement_mean:.4f}%')
print(f'Percentage Improvement Confidence Interval Achieved: {improvement_margin_of_error:.4f}%')
print(f"Mean MSE: {mean_mse:.4f}")
print(f"95% Confidence Interval: ({mean_mse - margin_of_error:.4f}, {mean_mse + margin_of_error:.4f})")
print(f"Margin of Error: {margin_of_error:.4f}")

Percentage Improvement Mean Achieved: 55.0498%
Percentage Improvement Confidence Interval Achieved: 0.4695%
Mean MSE: 0.5015
95% Confidence Interval: (0.4963, 0.5068)
Margin of Error: 0.0052


# Change Dictionary

In [63]:
# Modality ablation: cameras and radar on, GPS off.
onoffdict = dict(GPS=False, CAMERAS=True, RADAR=True)
weight_path = f'models/TransNettask2/cr{reduction}/gps{onoffdict["GPS"]}_cam{onoffdict["CAMERAS"]}_rad{onoffdict["RADAR"]}/'

In [64]:
# Ensure the checkpoint directory exists; exist_ok avoids the check-then-create race.
os.makedirs(weight_path, exist_ok=True)

In [65]:
# Pretrained task-1 decoder for the same modality combination.
task1_weight_path = f'models/task1/gps{onoffdict["GPS"]}_cam{onoffdict["CAMERAS"]}_rad{onoffdict["RADAR"]}/'

In [66]:
class Decoderwithmsi(nn.Module):
    """Task-2 decoder fusing a baseline CSI reconstruction with multi-sensor side information.

    Pretrained parts (unpickled from notebook-level paths at construction time):
      * ``gp``/``rd``/``lc``/``cc``/``rc`` and ``linear``: sensor branches and fusion layer
        of the task-1 decoder stored at ``task1_weight_path + "task1Decoder.pth"``;
      * ``bde``: baseline task-2 decoder stored at ``weight_pathb + "task2Decoder.pth"``.
    New layers: ``output_fc1`` and ``output_fc2``. They merge the side-information feature
    with the flattened baseline reconstruction.

    Args:
        reduction: compression ratio; accepted for interface compatibility, unused here.
    """

    def __init__(self, reduction):
        super().__init__()
        self.task1decoder = torch.load(task1_weight_path+"task1Decoder.pth")
        self.gp = self.task1decoder.gp
        self.rd = self.task1decoder.rd
        self.lc = self.task1decoder.lc
        self.cc = self.task1decoder.cc
        self.rc = self.task1decoder.rc
        self.bde = torch.load(weight_pathb+"task2Decoder.pth")
        # When False, `bde` and the side-information branch run under torch.no_grad().
        self.allow_update = False

        self.linear = self.task1decoder.linear
        # Width of the side-information feature, read from the pretrained fusion layer
        # (the task-1 decoder builds it as max(32, (num_H//2)**2) outputs).
        side_dim = self.linear.out_features
        self.output_fc1 = nn.Linear(side_dim + num_H*num_H, 2*num_H*num_H)
        self.output_fc2 = nn.Linear(2*num_H*num_H, num_H*num_H)
        self.output_relu = nn.ReLU()

        # Initialise ONLY the newly created layers. Looping over self.modules() would also
        # re-initialise every Linear/Conv2d/BatchNorm2d inside the pretrained task1decoder
        # and bde, silently discarding the weights that were just loaded.
        for layer in (self.output_fc1, self.output_fc2):
            nn.init.xavier_uniform_(layer.weight)

    def _pretrained_stage(self, Hencoded, gps, radar, left_cam, center_cam, right_cam, onoffdict):
        """Run the reused pretrained modules.

        Returns:
            (Hdecoded, side): the output of the baseline decoder ``bde`` and the fused
            side-information feature produced by ``linear`` (before the ReLU).
        """
        bs = Hencoded.size(0)
        Hdecoded = self.bde(Hencoded)
        # Disabled modalities contribute a (bs, 16) zero block so the fused input width
        # stays 16*5. Allocate on the model's device, not the notebook-global `device`.
        param_device = self.linear.weight.device

        def zeros():
            return torch.zeros(bs, 16, device=param_device)

        gps_out = self.gp(gps) if onoffdict['GPS'] else zeros()
        radar_out = self.rd(radar) if onoffdict['RADAR'] else zeros()
        if onoffdict['CAMERAS']:
            left_cam, center_cam, right_cam = process_imgs(gps, left_cam, center_cam, right_cam, crop_size = (150,150))
            lc_out = self.lc(left_cam)
            cc_out = self.cc(center_cam)
            rc_out = self.rc(right_cam)
        else:
            lc_out, cc_out, rc_out = zeros(), zeros(), zeros()

        combined = torch.cat((gps_out, radar_out, lc_out, cc_out, rc_out), dim=1)
        return Hdecoded, self.linear(combined)

    def forward(self, Hencoded,gps,radar,left_cam,center_cam,right_cam,onoffdict):
        """Reconstruct the channel as a (bs, 1, n, n) tensor, n*n == output_fc2.out_features.

        The pretrained stage builds an autograd graph only when ``allow_update`` is True.
        ``output_fc1``/``output_fc2`` follow the caller's grad mode.
        """
        bs = Hencoded.size(0)
        stage_args = (Hencoded, gps, radar, left_cam, center_cam, right_cam, onoffdict)
        if self.allow_update:
            Hdecoded, side = self._pretrained_stage(*stage_args)
        else:
            with torch.no_grad():
                Hdecoded, side = self._pretrained_stage(*stage_args)

        fused = torch.cat((self.output_relu(side), Hdecoded.reshape(bs, -1)), dim=1)
        output = self.output_fc2(self.output_relu(self.output_fc1(fused)))
        side_len = math.isqrt(self.output_fc2.out_features)
        return output.view(bs, 1, side_len, side_len)

In [67]:
#complete task 2 model including encoder, decoder and channel
class task2model(nn.Module):
    """Complete task-2 model: pretrained baseline encoder followed by ``Decoderwithmsi``.

    The encoder is unpickled from ``weight_pathb + "task2Encoder.pth"`` (a notebook-level
    path). ``allow_update`` gates only the encoder; the decoder carries its own flag.
    """

    def __init__(self, reduction=16):
        super().__init__()
        self.en = torch.load(weight_pathb + "task2Encoder.pth")
        # Encoder starts frozen: its forward pass runs without autograd.
        self.allow_update = False
        self.de = Decoderwithmsi(reduction)

    def forward(self, Hin, gps, radar, left_cam, center_cam, right_cam, onoffdict):
        """Encode ``Hin`` and decode it together with the side-information inputs."""
        # Frozen encoder -> grad disabled; otherwise keep whatever mode the caller set.
        grad_mode = bool(self.allow_update) and torch.is_grad_enabled()
        with torch.set_grad_enabled(grad_mode):
            Hencoded = self.en(Hin)
        return self.de(Hencoded, gps, radar, left_cam, center_cam, right_cam, onoffdict)

In [68]:
# Build the task-2 model from its pretrained parts (note: the evaluation cell below
# rebinds `model` to the saved task2.pth checkpoint).
model = task2model(reduction=reduction)


In [69]:
# Training

In [70]:
# Reconstruction loss, also used as the evaluation metric below.
criterion = nn.MSELoss().to(device)

# Inference

In [71]:
# Evaluate the saved task-2 checkpoint num_runs times, then report the mean test MSE,
# the percentage improvement over target_loss, and their 95% confidence margins.
avg_mse_list = []
improvement_list = []
baseline_mse = target_loss.item()  # reference loss the improvement is measured against
for _ in range(num_runs):
    # shuffle=True: each run iterates the validation set in a different order.
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True, num_workers=5)
    # map_location lets a checkpoint saved from a GPU load when `device` fell back to 'cpu'.
    # NOTE(review): torch>=2.6 defaults to weights_only=True, which rejects whole pickled
    # modules like this one; pass weights_only=False on those versions.
    model = torch.load(weight_path + "task2.pth", map_location=device).to(device)
    avg_mse = run_test(model, test_loader, device, criterion)
    avg_mse_list.append(avg_mse)
    improvement = (baseline_mse - avg_mse) / baseline_mse * 100
    improvement_list.append(improvement)
mean_mse, margin_of_error = calculate_confidence_interval(avg_mse_list)
improvement_mean, improvement_margin_of_error = calculate_confidence_interval(improvement_list)
print(f'Percentage Improvement Mean Achieved: {improvement_mean:.4f}%')
print(f'Percentage Improvement Confidence Interval Achieved: {improvement_margin_of_error:.4f}%')
print(f"Mean MSE: {mean_mse:.4f}")
print(f"95% Confidence Interval: ({mean_mse - margin_of_error:.4f}, {mean_mse + margin_of_error:.4f})")
print(f"Margin of Error: {margin_of_error:.4f}")

Percentage Improvement Mean Achieved: 53.5477%
Percentage Improvement Confidence Interval Achieved: 0.6450%
Mean MSE: 0.5183
95% Confidence Interval: (0.5111, 0.5255)
Margin of Error: 0.0072


# Change Dictionary

In [72]:
# Camera-only configuration; the checkpoint directory is derived from these flags.
onoffdict = dict(GPS=False, CAMERAS=True, RADAR=False)
weight_path = f'models/TransNettask2/cr{reduction}/gps{onoffdict["GPS"]}_cam{onoffdict["CAMERAS"]}_rad{onoffdict["RADAR"]}/'

In [73]:
# Create the checkpoint directory for this configuration. exist_ok=True avoids the
# exists()/makedirs() race and keeps the cell safe to re-run.
os.makedirs(weight_path, exist_ok=True)

In [74]:
# Directory of the pretrained task-1 decoder matching the current modality flags.
task1_weight_path = f'models/task1/gps{onoffdict["GPS"]}_cam{onoffdict["CAMERAS"]}_rad{onoffdict["RADAR"]}/'

In [75]:
class Decoderwithmsi(nn.Module):
    """Task-2 decoder that fuses a baseline channel reconstruction with sensor features.

    Pretrained components are loaded at construction time:

    * ``task1_weight_path + "task1Decoder.pth"`` provides the GPS (``gp``), radar
      (``rd``) and left/center/right camera (``lc``/``cc``/``rc``) branches, plus the
      fusion layer ``linear``;
    * ``weight_pathb + "task2Decoder.pth"`` provides the baseline decoder ``bde``.

    Only the output head (``output_fc1`` -> ReLU -> ``output_fc2``) is created here.

    Args:
        reduction: accepted for interface compatibility; not used by this class.
    """

    # Width of the all-zero placeholder used for a disabled modality. It is defined at
    # class level (not in __init__) so whole pickled instances restored by torch.load
    # also see it.
    _DISABLED_FEAT_DIM = 16

    def __init__(self, reduction):
        super().__init__()
        self.task1decoder = torch.load(task1_weight_path + "task1Decoder.pth")
        self.gp = self.task1decoder.gp
        self.rd = self.task1decoder.rd
        self.lc = self.task1decoder.lc
        self.cc = self.task1decoder.cc
        self.rc = self.task1decoder.rc
        self.bde = torch.load(weight_pathb + "task2Decoder.pth")
        # While False, bde and the sensor branches run under no-grad (frozen).
        self.allow_update = False

        # Width of the fusion-layer output that is concatenated with the flattened
        # baseline reconstruction: (num_H/2)^2, but never less than 32.
        # NOTE(review): assumes self.linear really emits this many features.
        fused_dim = max(int(num_H / 2) * int(num_H / 2), 32)
        self.linear = self.task1decoder.linear
        self.output_fc1 = nn.Linear(fused_dim + num_H * num_H, 2 * num_H * num_H)
        self.output_fc2 = nn.Linear(2 * num_H * num_H, num_H * num_H)
        self.output_relu = nn.ReLU()

        # Initialise only the freshly created head. Looping over self.modules() would
        # also overwrite every pretrained Linear/Conv2d/BatchNorm2d loaded above.
        for layer in (self.output_fc1, self.output_fc2):
            nn.init.xavier_uniform_(layer.weight)

    def forward(self, Hencoded, gps, radar, left_cam, center_cam, right_cam, onoffdict):
        """Reconstruct the channel from the encoder output and the enabled sensors.

        Args:
            Hencoded: encoder output; dim 0 is the batch size.
            gps, radar, left_cam, center_cam, right_cam: sensor inputs. Each is only
                used when its modality is enabled in ``onoffdict``.
            onoffdict: mapping with boolean entries 'GPS', 'CAMERAS' and 'RADAR'.

        Returns:
            Tensor of shape ``(batch, 1, num_H, num_H)``.
        """
        bs = Hencoded.size(0)
        # Trainable: keep the caller's grad mode. Frozen: behaves like torch.no_grad().
        with torch.set_grad_enabled(self.allow_update and torch.is_grad_enabled()):
            Hdecoded = self.bde(Hencoded)
            sensor_feats = self._sensor_features(
                bs, Hencoded.device, gps, radar, left_cam, center_cam, right_cam, onoffdict)
            output = self.linear(sensor_feats)

        output = self.output_relu(output)
        combined2 = torch.cat((output, Hdecoded.view(bs, -1)), dim=1)
        output = self.output_relu(self.output_fc1(combined2))
        output = self.output_fc2(output)
        return output.view(bs, 1, num_H, num_H)

    def _sensor_features(self, bs, dev, gps, radar, left_cam, center_cam, right_cam, onoffdict):
        """Concatenate the GPS, radar and three camera features; disabled ones become zeros."""
        def placeholder():
            return torch.zeros(bs, self._DISABLED_FEAT_DIM, device=dev)

        gps_out = self.gp(gps) if onoffdict['GPS'] else placeholder()
        radar_out = self.rd(radar) if onoffdict['RADAR'] else placeholder()
        if onoffdict['CAMERAS']:
            # process_imgs (defined elsewhere) prepares 150x150 crops; it is also given gps.
            left_cam, center_cam, right_cam = process_imgs(
                gps, left_cam, center_cam, right_cam, crop_size=(150, 150))
            cam_outs = (self.lc(left_cam), self.cc(center_cam), self.rc(right_cam))
        else:
            cam_outs = (placeholder(), placeholder(), placeholder())
        return torch.cat((gps_out, radar_out, *cam_outs), dim=1)

In [76]:
# Complete task-2 model: pretrained encoder followed by the multimodal decoder.
class task2model(nn.Module):
    """End-to-end task-2 pipeline: ``self.en`` (encoder) -> ``self.de`` (Decoderwithmsi).

    The encoder is loaded from ``weight_pathb + "task2Encoder.pth"``; ``reduction``
    is only forwarded to the ``Decoderwithmsi`` constructor.
    """

    def __init__(self, reduction=16):
        super().__init__()
        self.en = torch.load(weight_pathb + "task2Encoder.pth")
        # While False, the encoder runs under no-grad (frozen).
        self.allow_update = False
        self.de = Decoderwithmsi(reduction)

    def forward(self, Hin, gps, radar, left_cam, center_cam, right_cam, onoffdict):
        """Encode ``Hin`` and decode it together with the side-information inputs.

        Returns whatever ``self.de`` returns for the encoded input.
        """
        # Trainable encoder: keep the caller's grad mode unchanged.
        # Frozen encoder: equivalent to wrapping the call in torch.no_grad().
        encoder_grad = self.allow_update and torch.is_grad_enabled()
        with torch.set_grad_enabled(encoder_grad):
            code = self.en(Hin)
        return self.de(code, gps, radar, left_cam, center_cam, right_cam, onoffdict)

In [77]:
# Instantiate the complete task-2 model; its constructor loads the pretrained weights from disk.
model = task2model(reduction=reduction)


In [78]:
# Training

In [79]:
# Loss: mean-squared error, averaged over every element of the batch.
criterion = nn.MSELoss(reduction="mean").to(device)

# Inference

In [80]:
# Evaluate the saved task-2 checkpoint num_runs times, then report the mean test MSE,
# the percentage improvement over target_loss, and their 95% confidence margins.
avg_mse_list = []
improvement_list = []
baseline_mse = target_loss.item()  # reference loss the improvement is measured against
for _ in range(num_runs):
    # shuffle=True: each run iterates the validation set in a different order.
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True, num_workers=5)
    # map_location lets a checkpoint saved from a GPU load when `device` fell back to 'cpu'.
    # NOTE(review): torch>=2.6 defaults to weights_only=True, which rejects whole pickled
    # modules like this one; pass weights_only=False on those versions.
    model = torch.load(weight_path + "task2.pth", map_location=device).to(device)
    avg_mse = run_test(model, test_loader, device, criterion)
    avg_mse_list.append(avg_mse)
    improvement = (baseline_mse - avg_mse) / baseline_mse * 100
    improvement_list.append(improvement)
mean_mse, margin_of_error = calculate_confidence_interval(avg_mse_list)
improvement_mean, improvement_margin_of_error = calculate_confidence_interval(improvement_list)
print(f'Percentage Improvement Mean Achieved: {improvement_mean:.4f}%')
print(f'Percentage Improvement Confidence Interval Achieved: {improvement_margin_of_error:.4f}%')
print(f"Mean MSE: {mean_mse:.4f}")
print(f"95% Confidence Interval: ({mean_mse - margin_of_error:.4f}, {mean_mse + margin_of_error:.4f})")
print(f"Margin of Error: {margin_of_error:.4f}")

Percentage Improvement Mean Achieved: 53.6652%
Percentage Improvement Confidence Interval Achieved: 0.5160%
Mean MSE: 0.5170
95% Confidence Interval: (0.5112, 0.5228)
Margin of Error: 0.0058


# Change Dictionary

In [81]:
# Radar-only configuration; the checkpoint directory is derived from these flags.
onoffdict = dict(GPS=False, CAMERAS=False, RADAR=True)
weight_path = f'models/TransNettask2/cr{reduction}/gps{onoffdict["GPS"]}_cam{onoffdict["CAMERAS"]}_rad{onoffdict["RADAR"]}/'

In [82]:
# Create the checkpoint directory for this configuration. exist_ok=True avoids the
# exists()/makedirs() race and keeps the cell safe to re-run.
os.makedirs(weight_path, exist_ok=True)

In [83]:
# Directory of the pretrained task-1 decoder matching the current modality flags.
task1_weight_path = f'models/task1/gps{onoffdict["GPS"]}_cam{onoffdict["CAMERAS"]}_rad{onoffdict["RADAR"]}/'

In [84]:
class Decoderwithmsi(nn.Module):
    """Task-2 decoder that fuses a baseline channel reconstruction with sensor features.

    Pretrained components are loaded at construction time:

    * ``task1_weight_path + "task1Decoder.pth"`` provides the GPS (``gp``), radar
      (``rd``) and left/center/right camera (``lc``/``cc``/``rc``) branches, plus the
      fusion layer ``linear``;
    * ``weight_pathb + "task2Decoder.pth"`` provides the baseline decoder ``bde``.

    Only the output head (``output_fc1`` -> ReLU -> ``output_fc2``) is created here.

    Args:
        reduction: accepted for interface compatibility; not used by this class.
    """

    # Width of the all-zero placeholder used for a disabled modality. It is defined at
    # class level (not in __init__) so whole pickled instances restored by torch.load
    # also see it.
    _DISABLED_FEAT_DIM = 16

    def __init__(self, reduction):
        super().__init__()
        self.task1decoder = torch.load(task1_weight_path + "task1Decoder.pth")
        self.gp = self.task1decoder.gp
        self.rd = self.task1decoder.rd
        self.lc = self.task1decoder.lc
        self.cc = self.task1decoder.cc
        self.rc = self.task1decoder.rc
        self.bde = torch.load(weight_pathb + "task2Decoder.pth")
        # While False, bde and the sensor branches run under no-grad (frozen).
        self.allow_update = False

        # Width of the fusion-layer output that is concatenated with the flattened
        # baseline reconstruction: (num_H/2)^2, but never less than 32.
        # NOTE(review): assumes self.linear really emits this many features.
        fused_dim = max(int(num_H / 2) * int(num_H / 2), 32)
        self.linear = self.task1decoder.linear
        self.output_fc1 = nn.Linear(fused_dim + num_H * num_H, 2 * num_H * num_H)
        self.output_fc2 = nn.Linear(2 * num_H * num_H, num_H * num_H)
        self.output_relu = nn.ReLU()

        # Initialise only the freshly created head. Looping over self.modules() would
        # also overwrite every pretrained Linear/Conv2d/BatchNorm2d loaded above.
        for layer in (self.output_fc1, self.output_fc2):
            nn.init.xavier_uniform_(layer.weight)

    def forward(self, Hencoded, gps, radar, left_cam, center_cam, right_cam, onoffdict):
        """Reconstruct the channel from the encoder output and the enabled sensors.

        Args:
            Hencoded: encoder output; dim 0 is the batch size.
            gps, radar, left_cam, center_cam, right_cam: sensor inputs. Each is only
                used when its modality is enabled in ``onoffdict``.
            onoffdict: mapping with boolean entries 'GPS', 'CAMERAS' and 'RADAR'.

        Returns:
            Tensor of shape ``(batch, 1, num_H, num_H)``.
        """
        bs = Hencoded.size(0)
        # Trainable: keep the caller's grad mode. Frozen: behaves like torch.no_grad().
        with torch.set_grad_enabled(self.allow_update and torch.is_grad_enabled()):
            Hdecoded = self.bde(Hencoded)
            sensor_feats = self._sensor_features(
                bs, Hencoded.device, gps, radar, left_cam, center_cam, right_cam, onoffdict)
            output = self.linear(sensor_feats)

        output = self.output_relu(output)
        combined2 = torch.cat((output, Hdecoded.view(bs, -1)), dim=1)
        output = self.output_relu(self.output_fc1(combined2))
        output = self.output_fc2(output)
        return output.view(bs, 1, num_H, num_H)

    def _sensor_features(self, bs, dev, gps, radar, left_cam, center_cam, right_cam, onoffdict):
        """Concatenate the GPS, radar and three camera features; disabled ones become zeros."""
        def placeholder():
            return torch.zeros(bs, self._DISABLED_FEAT_DIM, device=dev)

        gps_out = self.gp(gps) if onoffdict['GPS'] else placeholder()
        radar_out = self.rd(radar) if onoffdict['RADAR'] else placeholder()
        if onoffdict['CAMERAS']:
            # process_imgs (defined elsewhere) prepares 150x150 crops; it is also given gps.
            left_cam, center_cam, right_cam = process_imgs(
                gps, left_cam, center_cam, right_cam, crop_size=(150, 150))
            cam_outs = (self.lc(left_cam), self.cc(center_cam), self.rc(right_cam))
        else:
            cam_outs = (placeholder(), placeholder(), placeholder())
        return torch.cat((gps_out, radar_out, *cam_outs), dim=1)

In [85]:
# Complete task-2 model: pretrained encoder followed by the multimodal decoder.
class task2model(nn.Module):
    """End-to-end task-2 pipeline: ``self.en`` (encoder) -> ``self.de`` (Decoderwithmsi).

    The encoder is loaded from ``weight_pathb + "task2Encoder.pth"``; ``reduction``
    is only forwarded to the ``Decoderwithmsi`` constructor.
    """

    def __init__(self, reduction=16):
        super().__init__()
        self.en = torch.load(weight_pathb + "task2Encoder.pth")
        # While False, the encoder runs under no-grad (frozen).
        self.allow_update = False
        self.de = Decoderwithmsi(reduction)

    def forward(self, Hin, gps, radar, left_cam, center_cam, right_cam, onoffdict):
        """Encode ``Hin`` and decode it together with the side-information inputs.

        Returns whatever ``self.de`` returns for the encoded input.
        """
        # Trainable encoder: keep the caller's grad mode unchanged.
        # Frozen encoder: equivalent to wrapping the call in torch.no_grad().
        encoder_grad = self.allow_update and torch.is_grad_enabled()
        with torch.set_grad_enabled(encoder_grad):
            code = self.en(Hin)
        return self.de(code, gps, radar, left_cam, center_cam, right_cam, onoffdict)

In [86]:
# Instantiate the complete task-2 model; its constructor loads the pretrained weights from disk.
model = task2model(reduction=reduction)


In [87]:
# Training

# Inference

In [88]:
# Evaluate the saved task-2 checkpoint num_runs times, then report the mean test MSE,
# the percentage improvement over target_loss, and their 95% confidence margins.
avg_mse_list = []
improvement_list = []
baseline_mse = target_loss.item()  # reference loss the improvement is measured against
for _ in range(num_runs):
    # shuffle=True: each run iterates the validation set in a different order.
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True, num_workers=5)
    # map_location lets a checkpoint saved from a GPU load when `device` fell back to 'cpu'.
    # NOTE(review): torch>=2.6 defaults to weights_only=True, which rejects whole pickled
    # modules like this one; pass weights_only=False on those versions.
    model = torch.load(weight_path + "task2.pth", map_location=device).to(device)
    avg_mse = run_test(model, test_loader, device, criterion)
    avg_mse_list.append(avg_mse)
    improvement = (baseline_mse - avg_mse) / baseline_mse * 100
    improvement_list.append(improvement)
mean_mse, margin_of_error = calculate_confidence_interval(avg_mse_list)
improvement_mean, improvement_margin_of_error = calculate_confidence_interval(improvement_list)
print(f'Percentage Improvement Mean Achieved: {improvement_mean:.4f}%')
print(f'Percentage Improvement Confidence Interval Achieved: {improvement_margin_of_error:.4f}%')
print(f"Mean MSE: {mean_mse:.4f}")
print(f"95% Confidence Interval: ({mean_mse - margin_of_error:.4f}, {mean_mse + margin_of_error:.4f})")
print(f"Margin of Error: {margin_of_error:.4f}")

Percentage Improvement Mean Achieved: 3.9752%
Percentage Improvement Confidence Interval Achieved: 1.0468%
Mean MSE: 1.0714
95% Confidence Interval: (1.0597, 1.0831)
Margin of Error: 0.0117
