In [1]:
import numpy as np
import pandas as pd
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "3"  # Select the GPU index
import scipy.io
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import math
from PIL import Image
from collections import OrderedDict
import torchvision.transforms as T
import matplotlib.pyplot as plt
import time
from torch.optim.lr_scheduler import _LRScheduler
import warnings
from scipy.io import savemat
from scipy import stats
import dill as pickle
import thop
from torch_challenge_dataset import DeepVerseChallengeLoaderTaskThree
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=UserWarning)

In [2]:
# Parameters
# Flags for the optional side-information modalities; their values are embedded
# in the checkpoint directory name below.
onoffdict = {'GPS': False, 'CAMERAS': False, 'RADAR': False}
# Reduction factor; it selects the `cr{reduction}` checkpoint sub-directory.
reduction = 4
# Directory holding the task-3 checkpoint for this configuration.
weight_path = f'models/ACRNettask3/cr{reduction}/gps{onoffdict["GPS"]}_cam{onoffdict["CAMERAS"]}_rad{onoffdict["RADAR"]}/'

In [3]:
# Make sure the checkpoint directory exists. `exist_ok=True` replaces the
# separate os.path.exists() check, which could race with another process
# creating the same directory between the check and the call.
os.makedirs(weight_path, exist_ok=True)

In [4]:
# Number of samples the DataLoader groups into one batch.
batch_size = 200


In [5]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Utils and Models

In [6]:
def CSI_abs_reshape(y, csi_std=2.8117975e-06, target_std=1.0):
    """Return the element-wise magnitude of ``y`` rescaled to a target spread.

    Args:
        y: Tensor (real or complex) whose absolute value is taken.
        csi_std: Divisor applied to the magnitudes.
        target_std: Multiplier applied after the division.

    Returns:
        ``(|y| / csi_std) * target_std`` with the same shape as ``y``.
    """
    magnitude = torch.abs(y)
    return (magnitude / csi_std) * target_std

In [7]:
def CSI_reshape(y, csi_std=2.5e-06, target_std=1):
    """Split a complex tensor into real/imaginary parts and rescale it.

    The real part and the imaginary part of ``y`` are concatenated along
    dimension 1 (real first), then the result is divided by ``csi_std`` and
    multiplied by ``target_std``.

    Args:
        y: Complex tensor with at least two dimensions.
        csi_std: Divisor applied to the stacked values.
        target_std: Multiplier applied after the division.

    Returns:
        Real tensor whose size along dimension 1 is twice that of ``y``.
    """
    parts = (torch.real(y), torch.imag(y))
    stacked = torch.cat(parts, dim=1)
    return (stacked / csi_std) * target_std

In [8]:
def cal_model_parameters(model):
    """Count the scalar entries over everything ``model.parameters()`` yields.

    Args:
        model: Object exposing a ``parameters()`` iterable of tensors.

    Returns:
        Total number of elements as a Python ``int`` (0 when there are none).
    """
    return sum(int(param.numel()) for param in model.parameters())

In [9]:
class ConvBN(nn.Sequential):
    """Bias-free 2-D convolution followed by batch normalisation.

    The padding is ``(k - 1) // 2`` per kernel dimension, so odd kernels with
    stride 1 keep the spatial size. The sub-modules are registered under the
    names ``conv`` and ``bn``.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        kernel_size: Either one int or a sequence with one int per dimension.
        stride: Convolution stride.
        groups: Number of convolution groups.
    """

    def __init__(self, in_planes, out_planes, kernel_size, stride=1, groups=1):
        if isinstance(kernel_size, int):
            padding = (kernel_size - 1) // 2
        else:
            padding = [(k - 1) // 2 for k in kernel_size]

        conv = nn.Conv2d(in_channels=in_planes,
                         out_channels=out_planes,
                         kernel_size=kernel_size,
                         stride=stride,
                         padding=padding,
                         groups=groups,
                         bias=False)
        bn = nn.BatchNorm2d(out_planes)
        super().__init__(OrderedDict([('conv', conv), ('bn', bn)]))

In [10]:
class ACRDecoderBlock(nn.Module):
    r"""Residual decoder block: widen, grouped conv, project back.

    A 2-channel input is expanded to ``8 * expansion`` channels with a 1x9
    convolution, passed through a 7x7 grouped convolution
    (``4 * expansion`` groups), and projected back to 2 channels with a 9x1
    convolution. The input is added to that result before the final PReLU.
    """

    def __init__(self, expansion=20):
        super().__init__()
        width = 8 * expansion

        # Expansion stage: 2 -> width channels.
        self.conv1_bn = ConvBN(2, width, [1, 9])
        self.prelu1 = nn.PReLU(num_parameters=width, init=0.3)
        # Grouped stage at constant width.
        self.conv2_bn = ConvBN(width, width, 7, groups=4 * expansion)
        self.prelu2 = nn.PReLU(num_parameters=width, init=0.3)
        # Projection stage: width -> 2 channels.
        self.conv3_bn = ConvBN(width, 2, [9, 1])
        self.prelu3 = nn.PReLU(num_parameters=2, init=0.3)
        self.identity = nn.Identity()

    def forward(self, x):
        shortcut = self.identity(x)

        branch = self.conv1_bn(x)
        branch = self.prelu1(branch)
        branch = self.conv2_bn(branch)
        branch = self.prelu2(branch)
        branch = self.conv3_bn(branch)

        # The activation is applied after the skip connection is added.
        return self.prelu3(shortcut + branch)


class ACREncoderBlock(nn.Module):
    """Residual encoder block built from a 1x9 and a 9x1 convolution.

    Both convolutions keep 2 channels. The input is added to the output of the
    second convolution before the final PReLU.
    """

    def __init__(self):
        super().__init__()
        self.conv_bn1 = ConvBN(2, 2, [1, 9])
        self.prelu1 = nn.PReLU(num_parameters=2, init=0.3)
        self.conv_bn2 = ConvBN(2, 2, [9, 1])
        self.prelu2 = nn.PReLU(num_parameters=2, init=0.3)
        self.identity = nn.Identity()

    def forward(self, x):
        shortcut = self.identity(x)

        branch = self.conv_bn1(x)
        branch = self.prelu1(branch)
        branch = self.conv_bn2(branch)

        # The activation is applied after the skip connection is added.
        return self.prelu2(shortcut + branch)

In [11]:
class task2Encoder(nn.Module):
    """Convolutional feature extractor followed by a linear compression layer.

    A 1-channel input is mapped to 2 channels, refined by two
    ``ACREncoderBlock`` stages, flattened and compressed by a fully connected
    layer from ``total_size`` (8192) to ``total_size // reduction`` values.
    The flattened feature map must therefore hold 8192 entries per sample
    (e.g. 2 channels x 64 x 64).

    Args:
        reduction: Compression factor of the fully connected layer.
        expansion: Accepted but not used by the encoder.
    """

    def __init__(self, reduction=16, expansion=20):
        super().__init__()
        self.total_size = 8192
        n1 = int(math.log2(reduction))  # computed but not used below
        feature_layers = [
            ("conv5x5_bn", ConvBN(1, 2, 5)),
            ("prelu", nn.PReLU(num_parameters=2, init=0.3)),
            ("ACREncoderBlock1", ACREncoderBlock()),
            ("ACREncoderBlock2", ACREncoderBlock()),
        ]
        self.encoder_feature = nn.Sequential(OrderedDict(feature_layers))
        self.encoder_fc = nn.Linear(self.total_size, self.total_size // reduction)
        # Registered on the module, but forward() never applies it.
        self.output_sig = nn.Sigmoid()

        # Xavier-uniform weights for conv/linear layers, identity-like batch norm.
        for module in self.modules():
            if isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
            elif isinstance(module, (nn.Conv2d, nn.Linear)):
                nn.init.xavier_uniform_(module.weight)

    def forward(self, x):
        # Unpacking four sizes means the input must be a 4-D tensor.
        batch, _channels, _height, _width = x.detach().size()

        features = self.encoder_feature(x.to(torch.float32))
        return self.encoder_fc(features.view(batch, -1))
       

In [12]:
class task2Decoder(nn.Module):
    """Expand a compressed code back into a single-channel 64x64 map.

    Pipeline: linear layer (``total_size // reduction`` -> 8192), reshape to
    2 x 64 x 64, convolutional refinement ending in a sigmoid, then a second
    linear layer (8192 -> 4096) with a sigmoid, reshaped to 1 x 64 x 64.

    Args:
        reduction: Compression factor; fixes the expected code length.
        expansion: Width multiplier handed to both ``ACRDecoderBlock`` stages.
    """

    def __init__(self, reduction=16, expansion=20):
        super().__init__()
        self.total_size = 8192
        self.reduced_size = self.total_size // reduction
        self.decoder_fc1 = nn.Linear(self.total_size // reduction, self.total_size)
        feature_layers = [
            ("conv5x5_bn", ConvBN(2, 2, 5)),
            ("prelu", nn.PReLU(num_parameters=2, init=0.3)),
            ("ACRDecoderBlock1", ACRDecoderBlock(expansion=expansion)),
            ("ACRDecoderBlock2", ACRDecoderBlock(expansion=expansion)),
            ("sigmoid", nn.Sigmoid()),
        ]
        self.decoder_feature = nn.Sequential(OrderedDict(feature_layers))

        # Xavier-uniform weights for conv/linear layers, identity-like batch norm.
        for module in self.modules():
            if isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
            elif isinstance(module, (nn.Conv2d, nn.Linear)):
                nn.init.xavier_uniform_(module.weight)

        # Created after the init loop above, so this layer keeps nn.Linear's
        # default initialisation.
        self.decoder_fc2 = nn.Linear(self.total_size, self.total_size // 2)
        self.sig2 = nn.Sigmoid()

    def forward(self, Hencoded):
        bs = Hencoded.size(0)
        code = Hencoded.view(bs, self.reduced_size)

        # Expand the code and refine it as a multi-channel 64x64 feature map.
        expanded = self.decoder_fc1(code)
        refined = self.decoder_feature(expanded.view(bs, -1, 64, 64))

        # Collapse to total_size // 2 values per sample, squashed into (0, 1).
        squashed = self.sig2(self.decoder_fc2(refined.view(bs, -1)))
        return squashed.view(bs, -1, 64, 64)

In [13]:

# Directory of the task-2 checkpoints (no side information) for the current
# reduction factor; the task-3 encoder/decoder load their sub-networks from it.
task2weight_path = f'models/ACRNettask2/cr{reduction}/gpsFalse_camFalse_radFalse/'


In [14]:
class task3Encoder(nn.Module):
    """Wrapper around the task-2 encoder loaded from disk.

    The wrapped encoder is read from ``task2weight_path + "task2Encoder.pth"``.
    While ``allow_update`` is False (the initial state) it is run under
    ``torch.no_grad()``, so no gradients are recorded for it.

    Args:
        reduction: Not used here; the reduction factor is already part of
            ``task2weight_path``.
    """

    def __init__(self, reduction=16):
        super().__init__()
        # An untrained alternative would be task2Encoder(reduction); the stored
        # encoder is loaded instead.
        # NOTE(review): torch.load unpickles the file and can execute code from
        # it; only point this at trusted checkpoints.
        self.en = torch.load(task2weight_path + "task2Encoder.pth")
        self.allow_update = False  # gradients for the encoder start disabled

    def forward(self, x):
        if not self.allow_update:
            with torch.no_grad():
                return self.en(x)
        return self.en(x)
       

In [15]:
class task3Decoder(nn.Module):
    """Task-2 decoder combined with a learnable second-order autoregression.

    The wrapped decoder is read from ``task2weight_path + "task2Decoder.pth"``.
    Its flattened output for the current step is mixed element-wise with the
    two previous outputs:

        out = decoded * a + prev1 * b + prev2 * c + d

    where ``a``, ``b``, ``c`` and ``d`` are parameters of length
    ``total_size // 2`` drawn from a standard normal distribution.

    Args:
        reduction: Not used here; the reduction factor is already part of
            ``task2weight_path``.
    """

    def __init__(self, reduction=16):
        super().__init__()
        self.total_size = 8192
        # NOTE(review): torch.load unpickles the file and can execute code from
        # it; only point this at trusted checkpoints.
        self.de = torch.load(task2weight_path + "task2Decoder.pth")

        # Per-element autoregression coefficients.
        coeff_len = self.total_size // 2
        self.a = nn.Parameter(torch.randn(coeff_len))
        self.b = nn.Parameter(torch.randn(coeff_len))
        self.c = nn.Parameter(torch.randn(coeff_len))
        self.d = nn.Parameter(torch.randn(coeff_len))
        self.allow_update = False  # gradients for the wrapped decoder start disabled

    def forward(self, Hencoded, input_autoregressive_features):
        """Decode one step.

        Args:
            Hencoded: Code passed to the wrapped decoder.
            input_autoregressive_features: Tensor indexed as ``[:, 0, :]`` for
                the previous output and ``[:, 1, :]`` for the one before it.

        Returns:
            Tuple ``(output, autoregressive_features)``: the mixed result
            viewed as ``(batch, 1, 64, 64)`` and the same values flattened to
            ``(batch, -1)``.
        """
        bs = Hencoded.size(0)
        a, b, c, d = (coeff.expand(bs, -1) for coeff in (self.a, self.b, self.c, self.d))

        out_tminus1 = input_autoregressive_features[:, 0, :].view(bs, -1)
        out_tminus2 = input_autoregressive_features[:, 1, :].view(bs, -1)

        if not self.allow_update:
            with torch.no_grad():
                out_t = self.de(Hencoded)
        else:
            out_t = self.de(Hencoded)

        autoregressive_features = (out_t.view(bs, -1)) * a + out_tminus1 * b + out_tminus2 * c + d
        output = autoregressive_features.view(bs, 1, 64, 64)
        return output, autoregressive_features
       

In [16]:
# Complete task-3 model: encoder plus autoregressive decoder.
class task3model(nn.Module):
    """Encode the input and decode it using the two previous decoder outputs.

    ``self.ar`` keeps the flattened decoder output of every time step of the
    current sequence (5 slots). Steps that have no history yet are fed zeros.
    """

    def __init__(self, reduction=16):
        super().__init__()
        self.total_size = 8192
        self.reduced_size = self.total_size // reduction
        self.en = task3Encoder(reduction)
        self.de = task3Decoder(reduction)
        self.ar = [None] * 5  # per-time-step autoregressive outputs

    def forward(self, X, time_index, device, is_training, onoffdict):
        """Run one time step.

        Args:
            X: Encoder input; its first dimension is the batch size.
            time_index: Position within the sequence; selects which stored
                outputs are used as history and where the new one is stored.
            device: Device on which the zero history tensors are placed.
            is_training: Accepted but not used.
            onoffdict: Accepted but not used.

        Returns:
            The decoder output for this step.
        """
        batch_size = X.shape[0]
        Hencoded = self.en(X)
        half = self.total_size // 2

        def zeros(slots):
            return torch.zeros((batch_size, slots, half), dtype=torch.float).to(device)

        def past(step):
            # History is detached, so gradients do not flow into earlier steps.
            return self.ar[step].view(batch_size, 1, half).detach()

        # Assemble [previous output, output before that] for the decoder.
        if time_index == 0:
            iarf = zeros(2)
        elif time_index == 1:
            iarf = torch.cat([past(0), zeros(1)], dim=1)
        else:
            iarf = torch.cat([past(time_index - 1), past(time_index - 2)], dim=1)

        Hdecoded, self.ar[time_index] = self.de(Hencoded, iarf)
        return Hdecoded


In [17]:
#Loss

#criterion=nn.BCELoss()
#criterion = nn.CrossEntropyLoss()
# Mean-squared-error loss, moved to the selected device.
criterion = nn.MSELoss().to(device)

# Inference

In [18]:
def run_test(model, test_loader, device, criterion):
    """Evaluate ``model`` on ``test_loader`` and return the mean per-step loss.

    Every item yielded by ``test_loader`` is treated as a sequence of
    ``(X, y)`` pairs, one per time step. The model's autoregressive buffer
    ``model.ar`` is reset before each sequence. The loss is averaged over the
    number of (batch, time-step) evaluations actually performed, so the result
    does not rely on a fixed sequence length.

    Args:
        model: Module called as
            ``model(Xin, time_index, device, is_training=..., onoffdict=...)``
            and exposing a writable ``ar`` attribute.
        test_loader: Iterable of sequences of ``(X, y)`` pairs.
        device: Device the inputs and targets are moved to.
        criterion: Callable ``criterion(prediction, target)`` returning a
            scalar tensor.

    Returns:
        The average loss as a Python float.

    Raises:
        ValueError: If the loader yields no time steps at all.
    """
    model.eval()
    total_loss = 0.0
    num_steps = 0
    with torch.no_grad():
        for sequence in test_loader:
            steps = list(sequence)
            # One history slot per time step of this sequence.
            model.ar = [None] * len(steps)
            for time_index, (X, y) in enumerate(steps):
                y_test_reshaped = CSI_abs_reshape(y.to(device))
                # Only X[0] is fed to the model — presumably the CSI entry of
                # the input container; confirm against the dataset loader.
                Xin = CSI_abs_reshape(X[0].to(device))
                y_pred = model(Xin, time_index, device, is_training=False, onoffdict=onoffdict)
                total_loss += criterion(y_pred, y_test_reshaped).item()
                num_steps += 1
    if num_steps == 0:
        raise ValueError("test_loader yielded no samples; cannot compute an average MSE")
    return total_loss / num_steps

In [19]:
def calculate_confidence_interval(data, confidence=0.95):
    """Mean and two-sided Student-t margin of error for a sample.

    Args:
        data: Sequence of numeric observations.
        confidence: Confidence level of the interval (0.95 by default).

    Returns:
        Tuple ``(mean, margin)``; the interval is ``mean ± margin``. The margin
        is the standard error of the mean times the t quantile at
        ``(1 + confidence) / 2`` with ``len(data) - 1`` degrees of freedom.
    """
    sample_mean = np.mean(data)
    dof = len(data) - 1
    t_quantile = stats.t.ppf((1 + confidence) / 2., dof)
    margin = stats.sem(data) * t_quantile
    return sample_mean, margin

In [20]:
# Dataset built from the validation CSV, served in shuffled batches.
test_dataset = DeepVerseChallengeLoaderTaskThree(csv_path=r'./dataset_validation.csv')
test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
)

In [21]:
# Baseline for the improvement figure: the mean squared deviation of the scaled
# target magnitudes from their own mean.
# Chunks are collected first and concatenated once; calling torch.cat inside the
# loop would re-copy everything gathered so far on every step.
h_chunks = []
for t_x in test_loader:
    for X, y in t_x:
        h_chunks.append(CSI_abs_reshape(y))
h_list = torch.cat(h_chunks) if h_chunks else torch.tensor([])
h_abs = torch.abs(h_list)
target_loss = torch.mean((h_abs - torch.mean(h_abs)) ** 2)

In [22]:
# Number of repeated evaluation passes used for the confidence interval.
num_runs = 10

In [23]:

# Repeat the evaluation `num_runs` times with a freshly shuffled loader and
# report the mean and 95% confidence interval of the MSE and of the relative
# improvement over the `target_loss` baseline.

# The checkpoint is loaded once, outside the loop: run_test only switches the
# model to eval mode and resets `model.ar`, so reloading per run is redundant.
# `map_location` lets a checkpoint saved on a GPU be opened on a CPU-only host.
# NOTE(review): the file is used as a complete pickled module (it is called
# directly after loading). Recent PyTorch releases default torch.load to
# weights_only=True, which rejects such files — confirm the installed version,
# and only load checkpoints from a trusted source.
model = torch.load(weight_path + "task3.pth", map_location=device).to(device)
baseline_loss = target_loss.item()

avg_mse_list = []
improvement_list = []
for _ in range(num_runs):
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
    avg_mse = run_test(model, test_loader, device, criterion)
    avg_mse_list.append(avg_mse)
    # Relative reduction of the MSE with respect to the baseline, in percent.
    improvement = (baseline_loss - avg_mse) / baseline_loss * 100
    improvement_list.append(improvement)
mean_mse, margin_of_error = calculate_confidence_interval(avg_mse_list)
improvement_mean, improvement_margin_of_error = calculate_confidence_interval(improvement_list)
print(f'Percentage Improvement Mean Achieved: {improvement_mean:.4f}%')
print(f'Percentage Improvement Confidence Interval Achieved: {improvement_margin_of_error:.4f}%')
print(f"Mean MSE: {mean_mse:.4f}")
print(f"95% Confidence Interval: ({mean_mse - margin_of_error:.4f}, {mean_mse + margin_of_error:.4f})")
print(f"Margin of Error: {margin_of_error:.4f}")

Percentage Improvement Mean Achieved: 39.8561%
Percentage Improvement Confidence Interval Achieved: 1.2695%
Mean MSE: 0.6789
95% Confidence Interval: (0.6645, 0.6932)
Margin of Error: 0.0143
