<a href="https://colab.research.google.com/github/wongdongwook/DeepJSCC-V-FeatureSelector-/blob/main/DeepJSSC_V_(Feature_Selector).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#GDN
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
# from torchvision import datasets, transforms
# from torchvision.utils import save_image
from torch.autograd import Function


class LowerBound(Function):
    """Elementwise ``max(inputs, bound)`` whose gradient still lets values rise off the bound.

    Forward clamps ``inputs`` from below at ``bound``.  Backward forwards the
    incoming gradient wherever the input was at/above the bound, and also
    wherever the gradient is negative (a descent step would increase the
    value), so entries clamped at the bound can still move upward.
    ``bound`` receives no gradient.
    """

    @staticmethod
    def forward(ctx, inputs, bound):
        floor = torch.ones_like(inputs) * bound
        ctx.save_for_backward(inputs, floor)
        return torch.max(inputs, floor)

    @staticmethod
    def backward(ctx, grad_output):
        inputs, floor = ctx.saved_tensors
        keep = (inputs >= floor) | (grad_output < 0)
        return grad_output * keep.type(grad_output.dtype), None


class GDN(nn.Module):
    """Generalized divisive normalization (GDN) layer.

    Forward (``inverse=False``)::

        y[i] = x[i] / sqrt(beta[i] + sum_j(gamma[i, j] * x[j]**2))

    Inverse (``inverse=True``, IGDN)::

        y[i] = x[i] * sqrt(beta[i] + sum_j(gamma[i, j] * x[j]**2))

    ``i``/``j`` index channels.  The sum is computed as a 1x1 convolution over
    channels, so the normalisation is applied independently at every spatial
    position.  ``beta`` and ``gamma`` are stored in a square-root
    reparameterisation (value = param**2 - pedestal) and lower-bounded so that
    beta >= beta_min and gamma >= 0.

    Args:
        ch: number of input channels.
        inverse: compute IGDN (multiply) instead of GDN (divide).
        beta_min: lower bound on beta; keeps the denominator away from zero.
        gamma_init: initial value of gamma's diagonal.
        reparam_offset: offset used by the square-root reparameterisation.

    Inputs are (batch, ch, H, W); 5-D inputs (batch, ch, D, W, H) are folded
    to 4-D for the computation and restored afterwards.
    """

    def __init__(self,
                 ch,
                 inverse=False,
                 beta_min=1e-6,
                 gamma_init=0.1,
                 reparam_offset=2**-18):
        super(GDN, self).__init__()
        self.inverse = inverse
        self.beta_min = beta_min
        self.gamma_init = gamma_init
        self.reparam_offset = reparam_offset

        self.build(ch)

    def build(self, ch):
        """Create the reparameterised ``beta`` (ones) and ``gamma`` (gamma_init * I) parameters."""
        self.pedestal = self.reparam_offset**2
        # Bounds on the stored (square-root) parameters: they translate to
        # beta >= beta_min and gamma >= 0 after param**2 - pedestal.
        self.beta_bound = ((self.beta_min + self.reparam_offset**2)**0.5)
        self.gamma_bound = self.reparam_offset

        # Create beta param (initial beta = 1 after reparameterisation)
        beta = torch.sqrt(torch.ones(ch)+self.pedestal)
        self.beta = nn.Parameter(beta)

        # Create gamma param (initial gamma = gamma_init * identity)
        eye = torch.eye(ch)
        g = self.gamma_init*eye
        g = g + self.pedestal
        gamma = torch.sqrt(g)

        self.gamma = nn.Parameter(gamma)

    def forward(self, inputs):
        unfold = False
        if inputs.dim() == 5:
            unfold = True
            bs, ch, d, w, h = inputs.size()
            inputs = inputs.view(bs, ch, d*w, h)

        _, ch, _, _ = inputs.size()

        # Beta bound and reparam
        beta = LowerBound.apply(self.beta, self.beta_bound)
        beta = beta**2 - self.pedestal

        # Gamma bound and reparam; gamma[i, j] becomes conv weight[out=i, in=j]
        gamma = LowerBound.apply(self.gamma, self.gamma_bound)
        gamma = gamma**2 - self.pedestal
        gamma = gamma.view(ch, ch, 1, 1)

        # Norm pool calc: sqrt(beta[i] + sum_j gamma[i, j] * x[j]**2)
        norm_ = nn.functional.conv2d(inputs**2, gamma, beta)
        norm_ = torch.sqrt(norm_)

        # Apply norm
        if self.inverse:
            outputs = inputs * norm_
        else:
            outputs = inputs / norm_

        if unfold:
            outputs = outputs.view(bs, ch, d, w, h)
        return outputs

In [None]:
# models
import numpy as np
import torch.nn as nn
import torch

def conv(in_channels, out_channels, kernel_size=3, stride=1, padding=1):
    """Return a bias-free ``nn.Conv2d`` (bias is redundant before GDN)."""
    layer_kwargs = dict(kernel_size=kernel_size, stride=stride, padding=padding, bias=False)
    return nn.Conv2d(in_channels, out_channels, **layer_kwargs)

def deconv(in_channels, out_channels, kernel_size=3, stride=1, padding=1, output_padding = 0):
    """Return a bias-free ``nn.ConvTranspose2d`` (upsamples when ``stride > 1``)."""
    layer = nn.ConvTranspose2d(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        output_padding=output_padding,
        bias=False,
    )
    return layer


class conv_block(nn.Module):
    """Conv -> GDN -> PReLU.

    Uses the ``GDN`` layer defined in this file: ``torch.nn`` has no ``GDN``
    attribute, so ``nn.GDN`` raised AttributeError on construction.
    """
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super(conv_block, self).__init__()
        self.conv = conv(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
        self.gdn = GDN(out_channels)
        self.prelu = nn.PReLU()

    def forward(self, x):
        out = self.conv(x)
        out = self.gdn(out)
        out = self.prelu(out)
        return out

class deconv_block(nn.Module):
    """Transposed conv -> GDN -> activation.

    The activation is chosen per call: ``'prelu'`` (default), ``'sigmoid'``, or
    none for any other value.  Uses the ``GDN`` layer defined in this file:
    ``torch.nn`` has no ``GDN`` attribute, so ``nn.GDN`` raised AttributeError.
    """
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, output_padding = 0):
        super(deconv_block, self).__init__()
        self.deconv = deconv(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding,  output_padding = output_padding)
        self.gdn = GDN(out_channels)
        self.prelu = nn.PReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, activate_func='prelu'):
        out = self.deconv(x)
        out = self.gdn(out)
        if activate_func=='prelu':
            out = self.prelu(out)
        elif activate_func=='sigmoid':
            out = self.sigmoid(out)
        return out


class conv_ResBlock(nn.Module):
    """Residual unit: conv -> GDN -> PReLU -> 1x1 conv -> GDN, plus a skip path, then PReLU.

    With ``use_conv1x1 == True`` the skip path goes through a 1x1 convolution
    using the block's stride, so its shape matches the main path when the
    channel count or resolution changes.  A single ``nn.PReLU`` module is
    applied both inside the main path and after the residual sum.
    """
    def __init__(self, in_channels, out_channels, use_conv1x1=False, kernel_size=3, stride=1, padding=1):
        super(conv_ResBlock, self).__init__()
        self.conv1 = conv(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
        self.conv2 = conv(out_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.gdn1 = GDN(out_channels)
        self.gdn2 = GDN(out_channels)
        self.prelu = nn.PReLU()
        self.use_conv1x1 = use_conv1x1
        if use_conv1x1 == True:
            # Projection shortcut.
            self.conv3 = conv(in_channels, out_channels, kernel_size=1, stride=stride, padding=0)

    def forward(self, x):
        residual = self.gdn1(self.conv1(x))
        residual = self.gdn2(self.conv2(self.prelu(residual)))
        shortcut = self.conv3(x) if self.use_conv1x1 == True else x
        return self.prelu(residual + shortcut)


class deconv_ResBlock(nn.Module):
    """Transposed-conv residual unit: deconv -> GDN -> PReLU -> 1x1 deconv -> GDN, plus a skip path.

    With ``use_deconv1x1 == True`` the skip path is a 1x1 transposed
    convolution sharing the block's stride/output_padding, so it upsamples in
    step with the main path.  The residual sum goes through ``'prelu'``
    (default) or ``'sigmoid'``, or is returned as-is for any other value.
    """
    def __init__(self, in_channels, out_channels, use_deconv1x1=False, kernel_size=3, stride=1, padding=1, output_padding=0):
        super(deconv_ResBlock, self).__init__()
        self.deconv1 = deconv(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, output_padding=output_padding)
        self.deconv2 = deconv(out_channels, out_channels, kernel_size=1, stride=1, padding=0, output_padding=0)
        self.gdn1 = GDN(out_channels)
        self.gdn2 = GDN(out_channels)
        self.prelu = nn.PReLU()
        self.sigmoid = nn.Sigmoid()
        self.use_deconv1x1 = use_deconv1x1
        if use_deconv1x1 == True:
            # Projection shortcut.
            self.deconv3 = deconv(in_channels, out_channels, kernel_size=1, stride=stride, padding=0, output_padding=output_padding)

    def forward(self, x, activate_func='prelu'):
        residual = self.gdn1(self.deconv1(x))
        residual = self.gdn2(self.deconv2(self.prelu(residual)))
        shortcut = self.deconv3(x) if self.use_deconv1x1 == True else x
        merged = residual + shortcut
        if activate_func == 'prelu':
            return self.prelu(merged)
        if activate_func == 'sigmoid':
            return self.sigmoid(merged)
        return merged

# Original Existing Works
class AF_block(nn.Module):
    """Attention-feature (AF) module: SNR-conditioned channel-wise gating.

    Global-average-pools ``x`` over its two spatial axes, appends the SNR,
    runs the result through Linear -> ReLU -> Linear -> Sigmoid, and rescales
    each channel of ``x`` by the resulting factor in (0, 1).

    Args:
        Nin: number of channels of ``x`` (the first layer is ``Nin + 1`` wide
            because the SNR is appended).
        Nh: hidden width of the MLP.
        No: number of gates; must broadcast against the channels of ``x``.
    """
    def __init__(self, Nin, Nh, No):
        super(AF_block, self).__init__()
        self.fc1 = nn.Linear(Nin+1, Nh)
        self.fc2 = nn.Linear(Nh, No)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, snr):
        """Gate ``x`` (batch, channels, H, W) with SNR-dependent per-channel weights.

        ``snr`` may be shaped (batch,), (batch, 1), or hold a single value that
        is shared by the whole batch.  Integer SNRs are cast to ``x``'s dtype.
        """
        mu = torch.mean(x, (2, 3))
        # Normalise snr to a (batch, 1) column.  reshape handles (B,), (B, 1)
        # and (1, 1) alike; squeeze/unsqueeze turned a (1, 1) input into a 3-D
        # tensor that made torch.cat fail for batch size 1.
        snr = torch.as_tensor(snr, device=mu.device).reshape(-1, 1).to(mu.dtype)
        if snr.shape[0] == 1:
            snr = snr.expand(mu.shape[0], 1)
        out = torch.cat((mu, snr), 1)
        out = self.fc1(out)
        out = self.relu(out)
        out = self.fc2(out)
        out = self.sigmoid(out)
        # (batch, No) -> (batch, No, 1, 1) so the gate broadcasts over H and W.
        out = out.unsqueeze(2).unsqueeze(3)
        return out*x

# The Encoder model with attention feature blocks
class Encoder(nn.Module):
    """Attention DeepJSCC encoder.

    Five residual conv blocks, each followed by an SNR-conditioned AF block.
    The first two blocks use stride 2, so (for an odd ``kernel_sz`` and H, W
    divisible by 4) an ``in_channels x H x W`` image becomes an
    ``enc_shape[0] x H/4 x W/4`` feature map, flattened to
    ``(batch, enc_shape[0] * H/4 * W/4)``.

    Args:
        enc_shape: (channels, height, width) of the encoded map; only
            ``enc_shape[0]`` is used here.
        kernel_sz: convolution kernel size ("same" padding of
            ``(kernel_sz - 1) // 2``).
        Nc_conv: number of hidden channels.
        in_channels: channels of the input image (3 for RGB, 1 for grayscale).
    """
    def __init__(self, enc_shape, kernel_sz, Nc_conv, in_channels=3):
        super(Encoder, self).__init__()
        enc_N = enc_shape[0]
        Nh_AF = Nc_conv//2
        padding_L = (kernel_sz-1)//2
        self.conv1 = conv_ResBlock(in_channels, Nc_conv, use_conv1x1=True, kernel_size = kernel_sz, stride = 2, padding=padding_L)
        self.conv2 = conv_ResBlock(Nc_conv, Nc_conv, use_conv1x1=True, kernel_size = kernel_sz, stride = 2, padding=padding_L)
        self.conv3 = conv_ResBlock(Nc_conv, Nc_conv, kernel_size = kernel_sz, stride = 1, padding=padding_L)
        self.conv4 = conv_ResBlock(Nc_conv, Nc_conv, kernel_size = kernel_sz, stride = 1, padding=padding_L)
        self.conv5 = conv_ResBlock(Nc_conv, enc_N, use_conv1x1=True, kernel_size = kernel_sz, stride = 1, padding=padding_L)
        self.AF1 = AF_block(Nc_conv, Nh_AF, Nc_conv)
        self.AF2 = AF_block(Nc_conv, Nh_AF, Nc_conv)
        self.AF3 = AF_block(Nc_conv, Nh_AF, Nc_conv)
        self.AF4 = AF_block(Nc_conv, Nh_AF, Nc_conv)
        self.AF5 = AF_block(enc_N, enc_N//2, enc_N)
        self.flatten = nn.Flatten()

    def forward(self, x, snr):
        out = self.conv1(x)
        out = self.AF1(out, snr)
        out = self.conv2(out)
        out = self.AF2(out, snr)
        out = self.conv3(out)
        out = self.AF3(out, snr)
        out = self.conv4(out)
        out = self.AF4(out, snr)
        out = self.conv5(out)
        out = self.AF5(out, snr)
        out = self.flatten(out)
        return out

# The Decoder model with attention feature blocks
class Decoder(nn.Module):
    """Attention DeepJSCC decoder mirroring ``Encoder``.

    Reshapes the received vector to ``enc_shape``, then alternates
    SNR-conditioned AF blocks with five residual transposed-conv blocks.  The
    first two blocks upsample by 2; the last ends in a sigmoid, so outputs lie
    in (0, 1).

    Args:
        enc_shape: (channels, height, width) of the encoded feature map.
        kernel_sz: transposed-convolution kernel size (padding
            ``(kernel_sz - 1) // 2``).
        Nc_deconv: number of hidden channels.
        out_channels: channels of the reconstructed image (3 for RGB, 1 for
            grayscale).
    """
    def __init__(self, enc_shape, kernel_sz, Nc_deconv, out_channels=3):
        super(Decoder, self).__init__()
        self.enc_shape = enc_shape
        Nh_AF1 = enc_shape[0]//2
        Nh_AF = Nc_deconv//2
        padding_L = (kernel_sz-1)//2
        self.deconv1 = deconv_ResBlock(self.enc_shape[0], Nc_deconv, use_deconv1x1=True, kernel_size = kernel_sz, stride = 2,  padding=padding_L, output_padding = 1)
        self.deconv2 = deconv_ResBlock(Nc_deconv, Nc_deconv, use_deconv1x1=True, kernel_size = kernel_sz, stride = 2,  padding=padding_L, output_padding = 1)
        self.deconv3 = deconv_ResBlock(Nc_deconv, Nc_deconv, kernel_size=kernel_sz, stride=1, padding=padding_L)
        self.deconv4 = deconv_ResBlock(Nc_deconv, Nc_deconv, kernel_size=kernel_sz, stride=1, padding=padding_L)
        self.deconv5 = deconv_ResBlock(Nc_deconv, out_channels, use_deconv1x1=True, kernel_size=kernel_sz, stride=1, padding=padding_L)

        self.AF1 = AF_block(self.enc_shape[0], Nh_AF1, self.enc_shape[0])
        self.AF2 = AF_block(Nc_deconv, Nh_AF, Nc_deconv)
        self.AF3 = AF_block(Nc_deconv, Nh_AF, Nc_deconv)
        self.AF4 = AF_block(Nc_deconv, Nh_AF, Nc_deconv)
        self.AF5 = AF_block(Nc_deconv, Nh_AF, Nc_deconv)

    def forward(self, x, snr):
        # Undo the encoder's flatten: (batch, C*H*W) -> (batch, C, H, W).
        out = x.view(-1, self.enc_shape[0], self.enc_shape[1], self.enc_shape[2])
        out = self.AF1(out, snr)
        out = self.deconv1(out)
        out = self.AF2(out, snr)
        out = self.deconv2(out)
        out = self.AF3(out, snr)
        out = self.deconv3(out)
        out = self.AF4(out, snr)
        out = self.deconv4(out)
        out = self.AF5(out, snr)
        out = self.deconv5(out, 'sigmoid')
        return out





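# Power normalization before transmission.
# Each row is scaled so that ||z||^2 = P * z_dim: average power P per real
# entry, i.e. 2P per complex symbol (pair of real entries). With P = 1 the
# symbol power is 2. For an average symbol power of 1, set P = 0.5 (power scales linearly with P).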
def Power_norm(z, P = 1):
    """Rescale every row of ``z`` so that its squared L2 norm equals ``P * z_dim``.

    Args:
        z: 2-D tensor (batch, z_dim) of real channel inputs.
        P: target average power per real entry.

    Returns:
        Tensor of the same shape as ``z``.
    """
    _, z_dim = z.shape
    row_norm = torch.sqrt(torch.sum(z**2, 1)).unsqueeze(1)
    return np.sqrt(P*z_dim)*z/row_norm

def Power_norm_complex(z, P = 1):
    """Power-normalise ``z`` viewed as interleaved complex symbols.

    Even/odd columns are the real/imaginary parts of ``z_dim // 2`` complex
    symbols.  Each row is scaled so its total power (sum of |symbol|^2) equals
    ``P * z_dim``, and the result is returned in the same interleaved layout.

    Args:
        z: 2-D real tensor (batch, z_dim) with even ``z_dim``.
        P: power scale (see above).

    Returns:
        Real tensor (batch, z_dim) on ``z``'s device.

    Raises:
        ValueError: if ``z_dim`` is odd.
    """
    batch_size, z_dim = z.shape
    if z_dim % 2:
        raise ValueError("Power_norm_complex needs an even feature length, got %d" % z_dim)
    z_com = torch.complex(z[:, 0:z_dim:2], z[:, 1:z_dim:2])
    z_power = torch.sum(z_com*z_com.conj(), 1).real
    z_nlz = np.sqrt(P*z_dim)*z_com/torch.sqrt(z_power).unsqueeze(1)
    # Allocate on z's device rather than forcing CUDA, so the function also runs on CPU.
    z_out = torch.zeros(batch_size, z_dim, device=z.device)
    z_out[:, 0:z_dim:2] = z_nlz.real
    z_out[:, 1:z_dim:2] = z_nlz.imag
    return z_out

# The (real) AWGN channel
def AWGN_channel(x, snr, P = 2):
    """Add real white Gaussian noise with variance ``P / 10**(snr/10)`` to every entry.

    Args:
        x: 2-D tensor (batch, length) of channel inputs.
        snr: SNR in dB; a number or a tensor broadcastable to (batch, 1)
            (e.g. one SNR per sample).
        P: noise power scale.

    Returns:
        Noisy tensor with the same shape and device as ``x``.

    NOTE(review): with ``Power_norm``'s default P=1 (unit power per real entry)
    and P=2 here, the per-entry noise variance is 2/gamma, so the effective SNR
    is gamma/2 (about 3 dB below ``snr``). Confirm this is intended.
    """
    batch_size, length = x.shape
    gamma = 10 ** (torch.as_tensor(snr, device=x.device) / 10.0)
    # Draw noise directly on x's device instead of forcing CUDA.
    noise = torch.sqrt(P/gamma)*torch.randn(batch_size, length, device=x.device)
    y = x+noise
    return y

def AWGN_complex(x, snr, Ps = 1):
    """Add complex white Gaussian noise to ``x``.

    Real and imaginary parts are independent, each with variance
    ``Ps / 10**(snr/10)``.

    Args:
        x: 2-D tensor (batch, length), real or complex.
        snr: SNR in dB; a number or a tensor broadcastable to (batch, 1).
        Ps: per-component noise power scale.

    Returns:
        Complex tensor of shape (batch, length) on ``x``'s device.
    """
    batch_size, length = x.shape
    gamma = 10 ** (torch.as_tensor(snr, device=x.device) / 10.0)
    std = torch.sqrt(Ps/gamma)
    # Generate on x's device instead of forcing CUDA.
    n_re = std*torch.randn(batch_size, length, device=x.device)
    n_im = std*torch.randn(batch_size, length, device=x.device)
    noise = torch.complex(n_re, n_im)
    y = x + noise
    return y

# Please set the symbol power if it is not a default value
def Fading_channel(x, snr, P = 2):
    """Per-symbol complex Gaussian fading channel with zero-forcing equalisation.

    Even/odd columns of ``x`` are the real/imaginary parts of
    ``feature_length // 2`` complex symbols.  Each symbol is multiplied by an
    independent fading coefficient h (real and imaginary parts ~ N(0, 1)),
    complex noise with per-component variance ``P / 10**(snr/10)`` is added,
    and the result is divided by h (perfect channel knowledge at the receiver).

    Args:
        x: 2-D real tensor (batch, feature_length), feature_length even.
        snr: SNR in dB; a number or a tensor broadcastable to (batch, 1).
        P: noise power scale.

    Returns:
        Real tensor (batch, feature_length) in the same interleaved layout,
        on ``x``'s device.
    """
    gamma = 10 ** (torch.as_tensor(snr, device=x.device) / 10.0)
    [batch_size, feature_length] = x.shape
    K = feature_length//2
    dev = x.device

    h_I = torch.randn(batch_size, K, device=dev)
    h_R = torch.randn(batch_size, K, device=dev)
    h_com = torch.complex(h_I, h_R)
    x_com = torch.complex(x[:, 0:feature_length:2], x[:, 1:feature_length:2])
    y_com = h_com*x_com

    n_I = torch.sqrt(P/gamma)*torch.randn(batch_size, K, device=dev)
    n_R = torch.sqrt(P/gamma)*torch.randn(batch_size, K, device=dev)
    noise = torch.complex(n_I, n_R)

    y_add = y_com + noise
    # Zero-forcing equalisation with the known channel h.
    y = y_add/h_com

    y_out = torch.zeros(batch_size, feature_length, device=dev)
    y_out[:, 0:feature_length:2] = y.real
    y_out[:, 1:feature_length:2] = y.imag
    return y_out



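# Note: each row is scaled so that ||z||^2 = Kv * P with Kv = ceil(z_dim * cr),
# i.e. roughly P per retained real entry and 2P per complex symbol; with P = 1 the
# symbol power is about 2. For an average symbol power of about 1, set P = 0.5.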
def Power_norm_VLC(z, cr, P = 1):
    """Power-normalise each row of ``z`` for variable-length coding.

    Row b is scaled so that its squared L2 norm equals ``Kv[b] * P`` with
    ``Kv = ceil(z_dim * cr)``, i.e. the power budget scales with the fraction
    of features actually transmitted.

    Args:
        z: 2-D tensor (batch, z_dim); in ``ADJSCC_V`` it has already been masked.
        cr: compression ratio per sample, shaped (batch, 1), (batch,), or a scalar.
        P: average power per retained real entry.

    Returns:
        Tensor of the same shape as ``z``, on ``z``'s device.
    """
    _, z_dim = z.shape
    cr = torch.as_tensor(cr, device=z.device)
    # One Kv per sample as a column so it broadcasts across each row.
    Kv = torch.ceil(z_dim*cr).int().reshape(-1, 1)
    z_power = torch.sqrt(torch.sum(z**2, 1)).unsqueeze(1)
    return torch.sqrt(Kv*P)*z/z_power


def AWGN_channel_VLC(x, snr, cr, P = 2):
    """Real AWGN channel for variable-length coding.

    Noise with variance ``P / 10**(snr/10)`` is added only at the positions
    kept by ``mask_gen(length, cr)``; masked-out positions stay noise-free.

    Args:
        x: 2-D tensor (batch, length).
        snr: SNR in dB; a number or a tensor broadcastable to (batch, 1).
        cr: per-sample compression ratio, passed to ``mask_gen``.
        P: noise power scale.

    Returns:
        Tensor of the same shape as ``x`` on ``x``'s device.
    """
    batch_size, length = x.shape
    gamma = 10 ** (torch.as_tensor(snr, device=x.device) / 10.0)
    mask = mask_gen(length, cr).to(x.device)
    # One independent noise draw per sample: a (1, length) draw would be
    # broadcast so that every sample in the batch received the same noise.
    noise = torch.sqrt(P/gamma)*torch.randn(batch_size, length, device=x.device)
    noise = noise*mask
    y = x+noise
    return y


def Fading_channel_VLC(x, snr, cr, P = 2):
    """Per-symbol complex Gaussian fading channel for variable-length coding.

    Same model as ``Fading_channel`` (fading h with N(0, 1) real/imaginary
    parts, complex noise with per-component variance ``P / 10**(snr/10)``,
    zero-forcing division by h).  Noise is applied only to the complex symbols
    kept by ``mask_gen(feature_length // 2, cr)``.

    Args:
        x: 2-D real tensor (batch, feature_length) of interleaved (real, imag) pairs.
        snr: SNR in dB; a number or a tensor broadcastable to (batch, 1).
        cr: per-sample compression ratio, passed to ``mask_gen``.
        P: noise power scale.

    Returns:
        Real tensor (batch, feature_length) on ``x``'s device.
    """
    gamma = 10 ** (torch.as_tensor(snr, device=x.device) / 10.0)
    [batch_size, feature_length] = x.shape
    K = feature_length//2
    dev = x.device

    # The mask is over complex symbols (K of them), not real entries.
    mask = mask_gen(K, cr).to(dev)
    h_I = torch.randn(batch_size, K, device=dev)
    h_R = torch.randn(batch_size, K, device=dev)
    h_com = torch.complex(h_I, h_R)
    x_com = torch.complex(x[:, 0:feature_length:2], x[:, 1:feature_length:2])
    y_com = h_com*x_com

    n_I = torch.sqrt(P/gamma)*torch.randn(batch_size, K, device=dev)
    n_R = torch.sqrt(P/gamma)*torch.randn(batch_size, K, device=dev)
    noise = torch.complex(n_I, n_R)*mask

    y_add = y_com + noise
    # Zero-forcing equalisation with the known channel h.
    y = y_add/h_com

    y_out = torch.zeros(batch_size, feature_length, device=dev)
    y_out[:, 0:feature_length:2] = y.real
    y_out[:, 1:feature_length:2] = y.imag
    return y_out


def Channel(z, snr, channel_type = 'AWGN'):
    """Power-normalise ``z`` and pass it through the selected channel.

    Args:
        z: 2-D tensor (batch, z_dim) of encoder outputs.
        snr: SNR in dB, forwarded to the channel function.
        channel_type: ``'AWGN'`` or ``'Fading'``.

    Raises:
        ValueError: for any other ``channel_type``. Without this check, a typo
            would silently return the normalised signal with no channel noise.
    """
    z = Power_norm(z)
    if channel_type == 'AWGN':
        return AWGN_channel(z, snr)
    if channel_type == 'Fading':
        return Fading_channel(z, snr)
    raise ValueError("channel_type must be 'AWGN' or 'Fading', got %r" % (channel_type,))


def Channel_VLC(z, snr, cr, channel_type = 'AWGN'):
    """Variable-length version of ``Channel``: VLC power normalisation plus masked channel.

    Args:
        z: 2-D tensor (batch, z_dim) of (masked) encoder outputs.
        snr: SNR in dB, forwarded to the channel function.
        cr: per-sample compression ratio.
        channel_type: ``'AWGN'`` or ``'Fading'``.

    Raises:
        ValueError: for any other ``channel_type``. Without this check, a typo
            would silently skip the channel.
    """
    z = Power_norm_VLC(z, cr)
    if channel_type == 'AWGN':
        return AWGN_channel_VLC(z, snr, cr)
    if channel_type == 'Fading':
        return Fading_channel_VLC(z, snr, cr)
    raise ValueError("channel_type must be 'AWGN' or 'Fading', got %r" % (channel_type,))


def mask_gen(N, cr, ch_max = 48):
    """Build per-sample prefix masks for variable-length coding.

    The N positions are treated as ``ch_max`` groups of ``N // ch_max``
    entries.  Row i keeps the first ``round(ch_max * cr[i])`` groups, i.e. the
    first ``(N // ch_max) * round(ch_max * cr[i])`` positions are 1 and the rest 0.

    Args:
        N: mask length.
        cr: compression ratios, shaped (batch, 1) or (batch,).
        ch_max: number of groups.

    Returns:
        int32 CPU tensor of shape (batch, N).
    """
    cr = torch.as_tensor(cr).reshape(-1).cpu()
    nc = N//ch_max
    lengths = nc*torch.round(ch_max*cr).int()
    # Vectorised prefix mask. This replaces a per-row Python loop, which forced
    # a device sync per sample when cr lived on the GPU.
    positions = torch.arange(N)
    return (positions.unsqueeze(0) < lengths.unsqueeze(1)).int()


class ADJSCC(nn.Module):
    """Fixed-rate attention DeepJSCC: encoder -> power-normalised channel -> decoder."""
    def __init__(self, enc_shape, Kernel_sz, Nc):
        super(ADJSCC, self).__init__()
        self.encoder = Encoder(enc_shape, Kernel_sz, Nc)
        self.decoder = Decoder(enc_shape, Kernel_sz, Nc)

    def forward(self, x, snr, channel_type = 'AWGN'):
        received = Channel(self.encoder(x, snr), snr, channel_type)
        return self.decoder(received, snr)

# The DeepJSCC_V model, also called ADJSCC_V
class ADJSCC_V(nn.Module):
    """Attention DeepJSCC with variable-length coding (DeepJSCC_V / ADJSCC_V).

    The encoder output is masked to a per-sample prefix whose length is set by
    ``cr`` (see ``mask_gen``). It then goes through the VLC channel and is
    decoded with the same SNR side information.
    """
    def __init__(self, enc_shape, Kernel_sz, Nc):
        super(ADJSCC_V, self).__init__()
        self.encoder = Encoder(enc_shape, Kernel_sz, Nc)
        self.decoder = Decoder(enc_shape, Kernel_sz, Nc)

    def forward(self, x, snr, cr, channel_type = 'AWGN'):
        z = self.encoder(x, snr)
        # Move the mask to z's device instead of forcing CUDA.
        z = z*mask_gen(z.shape[1], cr).to(z.device)
        z = Channel_VLC(z, snr, cr, channel_type)
        out = self.decoder(z, snr)
        return out

In [None]:
from keras.datasets import cifar100
from torch.utils.data import Dataset

# CIFAR-100 Data Loading Function
def load_cifar100_data():
    """Load CIFAR-100 images (labels are discarded) via Keras.

    Returns:
        (x_train, x_test): float32 arrays in channels-first layout
        [batch size, channel, height, width], with pixel values scaled to [0, 1].
    """
    (train_images, _), (test_images, _) = cifar100.load_data()

    def to_nchw_unit_range(images):
        # Keras gives channels-last images with 0-255 pixel values; move the
        # channel axis to position 1 and rescale.
        return np.transpose(images, (0, 3, 1, 2)).astype('float32') / 255

    return to_nchw_unit_range(train_images), to_nchw_unit_range(test_images)

# Dataset wrapper class used by the DataLoaders
class DatasetFolder(Dataset):
    """Map-style Dataset over an in-memory array: item ``i`` is ``matData[i]``."""

    def __init__(self, matData):
        self.matdata = matData

    def __len__(self):
        # Number of samples along the first axis.
        return self.matdata.shape[0]

    def __getitem__(self, index):
        sample = self.matdata[index]
        return sample

In [None]:
# import numpy as np
import torch
import torch.nn as nn
import os


# Training hyper-parameters.
BATCH_SIZE = 128
EPOCHS = 400
LEARNING_RATE = 1e-4
#LEARNING_RATE = 3e-4
PRINT_RREQ = 150  # print the training loss every PRINT_RREQ batches ("RREQ" presumably means "FREQ")

# Channel and model configuration.
CHANNEL = 'AWGN'  # Choose AWGN or Fading
IMG_SIZE = [3, 32, 32]  # CIFAR-100 image shape (channels, height, width)
N_channels = 256  # number of hidden channels passed to the encoder/decoder
Kernel_sz = 5  # convolution kernel size

# Load CIFAR-100 as float32 channels-first arrays in [0, 1].
x_train, x_test = load_cifar100_data()

# Wrap the arrays in DataLoaders (training set shuffled, test set in order).
train_dataset = DatasetFolder(x_train)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0, pin_memory=True)
test_dataset = DatasetFolder(x_test)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0, pin_memory=True)

current_epoch = 0
CONTINUE_TRAINING = False

# Kernel-size tag used in checkpoint file names, e.g. '5x5_'.
KSZ = str(Kernel_sz)+'x'+str(Kernel_sz)+'_'


In [None]:
#Test Metric (PSNR)
from skimage.metrics import peak_signal_noise_ratio as compute_pnsr


def Img_transform(test_rec):
    """Convert a (N, C, H, W) float tensor in [0, 1] to (N, H, W, C) uint8 images.

    Values are rounded to the nearest integer and clipped to [0, 255].
    Truncating with ``astype(np.uint8)`` alone maps e.g. 0.999 * 255 to 254.
    It can also map a pixel stored as k / 255 to k - 1 after float32
    round-off, which biases the PSNR computed on the result.

    Returns:
        numpy uint8 array of shape (N, H, W, C).
    """
    images = test_rec.permute(0, 2, 3, 1).cpu().detach().numpy()
    return np.clip(np.round(images*255), 0, 255).astype(np.uint8)

def Compute_batch_PSNR(test_input, test_rec):
    """Return the mean PSNR (dB) over a batch of image pairs.

    Args:
        test_input: reference images, indexed along axis 0.
        test_rec: reconstructions, in the same layout as ``test_input``.
    """
    scores = np.array(
        [compute_pnsr(test_input[idx, :], test_rec[idx, :]) for idx in range(test_input.shape[0])],
        dtype=float,
    )
    return np.mean(scores)


def Compute_IMG_PSNR(test_input, test_rec):
    """Return the per-image PSNR (dB) as a float column vector of shape (N, 1).

    Args:
        test_input: reference images, indexed along axis 0.
        test_rec: reconstructions, in the same layout as ``test_input``.
    """
    scores = [compute_pnsr(test_input[idx, :], test_rec[idx, :]) for idx in range(test_input.shape[0])]
    return np.array(scores, dtype=float).reshape(-1, 1)


In [None]:
LEARNING_RATE = 1e-4
CONTINUE_TRAINING= False
# Encoder output shape (channels, height, width): 48 feature channels, with the
# remaining entries giving the feature-map size (H/4 x W/4). Must be changed for ADJSCC.
# NOTE(review): 48 also equals mask_gen's default ch_max; presumably the two must agree. Confirm.
enc_out_shape = [48, IMG_SIZE[1]//4, IMG_SIZE[2]//4]

# Model, loss and optimizer (this setup requires a CUDA device).
DeepJSCC_V = ADJSCC_V(enc_out_shape, Kernel_sz, N_channels).cuda()
criterion = nn.MSELoss().cuda()
optimizer = torch.optim.Adam(DeepJSCC_V.parameters(), lr=LEARNING_RATE)

current_epoch = 0
bestLoss = 1e3  # sentinel "best" test loss; any MSE between images in [0, 1] is smaller
if CONTINUE_TRAINING == True:
    # Resume from a saved checkpoint. Only the model weights are restored; the
    # checkpoints hold no optimizer state, so Adam restarts its moment estimates.
    # NOTE(review): the path is hard-coded (Colab /content/...) and torch.load has
    # no map_location. Confirm before resuming on another machine.
    #DeepJSCC_V.load_state_dict(torch.load('./JSCC_models/DeepJSCC_VLC_'+KSZ+CHANNEL+'_'+str(N_channels)+'_201.pth.tar')['state_dict'])
    #current_epoch = 204
    current_epoch = 10
    DeepJSCC_V.load_state_dict(torch.load('/content/JSCC_models/DeepJSCC__5x5_AWGN_256_9.pth.tar')['state_dict'])


In [None]:
# DeepJSCC_V + Existing Attention
# Training loop: every batch is sent at a random per-sample SNR (integer dB in
# [0, 27]) and compression ratio CR in [0.1, 1.0). After each epoch the model is
# evaluated on the test set with freshly drawn SNR/CR, and a checkpoint is
# written whenever the average test MSE improves. Because the test SNR/CR are
# random, averageLoss fluctuates from epoch to epoch.
print('Training for DeepJSCC_V is started!')
bestLoss = 1e3

# current_epoch is not reset here, so a run resumed via CONTINUE_TRAINING
# continues from the epoch set when the checkpoint was loaded (0 for a fresh run).
EPOCHS= 101

for epoch in range(current_epoch, EPOCHS):
    DeepJSCC_V.train()
    print('========================')
    print('lr:%.4e'%optimizer.param_groups[0]['lr'])

    # Model training
    for i, x_input in enumerate(train_loader):
        x_input = x_input.cuda()

        SNR_TRAIN = torch.randint(0, 28, (x_input.shape[0], 1)).cuda()
        CR = 0.1+0.9*torch.rand(x_input.shape[0], 1).cuda()
        x_rec = DeepJSCC_V(x_input, SNR_TRAIN, CR, CHANNEL)
        # nn.MSELoss already averages (reduction='mean'), so no extra .mean() is needed.
        loss = criterion(x_rec, x_input)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if i % PRINT_RREQ == 0:
            print('Epoch: [{0}][{1}/{2}]\t' 'Loss {loss:.4f}\t'.format(epoch, i, len(train_loader), loss=loss.item()))

    # Model Evaluation
    DeepJSCC_V.eval()
    totalLoss = 0
    with torch.no_grad():
        for test_input in test_loader:
            test_input = test_input.cuda()
            SNR_TEST = torch.randint(0, 28, (test_input.shape[0], 1)).cuda()
            CR = 0.1+0.9*torch.rand(test_input.shape[0], 1).cuda()
            test_rec = DeepJSCC_V(test_input, SNR_TEST, CR, CHANNEL)
            # Weight by batch size so the average is per image, even for a short last batch.
            totalLoss += criterion(test_rec, test_input).item() * test_input.size(0)
        averageLoss = totalLoss / (len(test_dataset))
        print('averageLoss=', averageLoss)
        if averageLoss < bestLoss:
            # Model saving
            os.makedirs('./JSCC_models', exist_ok=True)
            torch.save({'state_dict': DeepJSCC_V.state_dict(), }, './JSCC_models/DeepJSCC_'+KSZ+CHANNEL+'_'+str(N_channels)+'_'+str(epoch)+'.pth.tar')

            print('Model saved')
            bestLoss = averageLoss

print('Training for DeepJSCC_V is finished!')


In [None]:
# DeepJSCC-V + Existing Attention
# Evaluation: average test-set PSNR for SNR = -3, 0, ..., 24 dB (3 dB steps)
# at a fixed compression ratio; results go into PSNR_ave[m, k].
kernel_sz = 5
KSZ = '_'+str(kernel_sz)+'x'+str(kernel_sz)+'_'
PSNR_ave = np.zeros((10, 10))

for m in range(0, 10):
    # NOTE(review): cr is identical for every m, so all 10 rows of PSNR_ave
    # evaluate the same setting and differ only by channel noise. Presumably a
    # per-row CR list was intended. Confirm.
    cr= 1/3

    for k in range(0, 10):
        snr_db = 3*(k-1)
        print('Evaluating DeepJSCC-v with CR = '+str(cr)+' and SNR = '+str(snr_db)+'dB')
        total_psnr = 0.0
        num_images = 0
        DeepJSCC_V.eval()
        with torch.no_grad():
            for test_input in test_loader:
                batch = test_input.shape[0]
                SNR = snr_db*torch.ones((batch, 1)).cuda()
                CR = cr*torch.ones((batch, 1)).cuda()
                test_input = test_input.cuda()

                test_rec = DeepJSCC_V(test_input, SNR,CR, CHANNEL)

                test_input = Img_transform(test_input)
                test_rec  = Img_transform(test_rec)
                # Weight each batch mean by its size so the result is the
                # per-image mean. Dividing by the last loop index was off by one
                # batch and ignored the short final batch.
                total_psnr += Compute_batch_PSNR(test_input, test_rec) * batch
                num_images += batch
            averagePSNR = total_psnr / num_images
            print('PSNR = ' + str(averagePSNR))

        PSNR_ave[m, k] = averagePSNR