In [1]:
import scipy.io

# Load the .mat file
file_path = '/home/rithvik/iitm2/csi-prediction/EV_Rank_1_52_RBs_50_UEs_1000_snaps.mat'
data = scipy.io.loadmat(file_path)

# Extract the relevant data
data = data['EV_re_im_split']

# Check the shape and structure of the extracted data
data = data[:25]
print(data.shape)


(25, 1000, 832)


In [2]:
type(data)

numpy.ndarray

In [3]:
import numpy as np

input_timesteps = 5
output_timesteps = 1

# Prepare the input and output sequences
X = []
y = []

for i in range(data.shape[0]):  # iterate over samples
    for j in range(data.shape[1] - input_timesteps):  # iterate over timesteps
        X.append(data[i, j:j+input_timesteps, :])
        y.append(data[i, j+input_timesteps, :])

X = np.array(X)
y = np.array(y)


In [4]:
import torch

X_tensor = torch.tensor(X, dtype=torch.float32)
y_tensor = torch.tensor(y, dtype=torch.float32)


In [5]:
import torch
from iTransformer import iTransformer

# Assuming X_tensor and y_tensor are already defined

# Initialize the model
model = iTransformer(
    num_variates=832,
    lookback_len=5,
    depth=6,
    dim=512,
    pred_length=1,
    dim_head=32,
    heads=8,
    attn_dropout=0.1,
    ff_mult=4,
    ff_dropout=0.1,
    num_mem_tokens=4,
    use_reversible_instance_norm=True,
    reversible_instance_norm_affine=True,
    flash_attn=True
)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loss_fn = torch.nn.MSELoss()

epochs = 100
batch_size = 32
num_batches = len(X_tensor) // batch_size

for epoch in range(epochs):
    model.train()
    for batch_idx in range(num_batches):
        start_idx = batch_idx * batch_size
        end_idx = start_idx + batch_size
        X_batch = X_tensor[start_idx:end_idx]
        y_batch = y_tensor[start_idx:end_idx]

        optimizer.zero_grad()
        
        # Forward pass
        output = model(X_batch, targets=y_batch)
        
        # Debugging prints
        print(f'Batch {batch_idx + 1}/{num_batches}')
        print(f'X_batch shape: {X_batch.shape}')
        print(f'y_batch shape: {y_batch.shape}')
        
        if isinstance(output, dict):
            for key, value in output.items():
                print(f'Pred shape for key {key}: {value.shape}')
        else:
            print(f'Output shape: {output.shape}')

        # Loss calculation
        loss = output
        loss.backward()
        optimizer.step()
    
    print(f'Epoch {epoch + 1}/{epochs}, Loss: {loss.item()}')

print('Training complete.')


Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda


AssertionError: 