In [3]:
%load_ext autoreload
%autoreload 2
import sys
import os
# Make the project's modules importable and resolve relative paths against the
# project root. NOTE(review): the absolute path is hard-coded, and re-running
# this cell appends a duplicate entry to sys.path each time.
sys.path.append("/workspaces/MambaLinearCode")
os.chdir("/workspaces/MambaLinearCode")


In [4]:
from configuration import Code, Config
from dataset import get_generator_and_parity
import torch
import os
import logging

def code_from_hint(hint,):
    """Build a ``Code`` from a file in ``CODES_PATH`` whose name contains ``hint``.

    ``CODES_PATH`` is read from module scope at call time. File names are
    parsed as ``<type>_N<n>_K<k>.<ext>`` (e.g. ``LDPC_N49_K24.alist``): the
    text before the first underscore is the code type, the second
    underscore-separated field minus its first character is ``n``, and the
    last field minus its first character and anything after the first dot
    is ``k``.

    When several file names contain ``hint``, the lexicographically smallest
    one is used, so the choice is stable across runs and file systems.

    Args:
        hint: Substring the file name must contain.

    Returns:
        ``Code(n, k, code_type)`` for the selected file.

    Raises:
        IndexError: If no file name in ``CODES_PATH`` contains ``hint``.
        ValueError: If the ``N``/``K`` fields do not parse as integers.
    """
    # os.listdir returns entries in arbitrary order; sort so that "first
    # match" means the same file on every run.
    matches = sorted(name for name in os.listdir(CODES_PATH) if hint in name)
    if not matches:
        # Same exception type a bare ``[0]`` lookup would raise, but with a
        # message that says what was searched for and where.
        raise IndexError(
            f"no code file in {CODES_PATH!r} contains {hint!r}")

    code_file = matches[0]
    print(code_file)

    fields = code_file.split('_')
    code_n = int(fields[1][1:])
    code_k = int(fields[-1][1:].split('.')[0])
    code_type = fields[0]
    return Code(code_n, code_k, code_type)

# Relative to the current working directory.
OUTPUT_PATH = ".output/"
CODES_PATH = "codes/"
# Pick the code definition whose file name contains this hint (the chosen file
# name is printed by code_from_hint).
example_code = code_from_hint("LDPC_N49_K24")
G,H = get_generator_and_parity(example_code, standard_form=True)
# G and H come back as numpy arrays; store them on the code object as integer
# tensors. The generator matrix is stored transposed; the parity-check matrix
# is stored as returned.
example_code.generator_matrix = torch.from_numpy(G).transpose(0,1).long()
example_code.pc_matrix = torch.from_numpy(H).long()

# The output directory must exist before the FileHandler below opens its log
# file inside it.
os.makedirs(OUTPUT_PATH, exist_ok=True)
# Hyper-parameters for this run; warmup_lr and lr are set to the same value.
config = Config(
    code=example_code,
    d_model=32,
    d_state=64,
    path=OUTPUT_PATH,
    N_dec=8,
    warmup_lr=1.0e-4,
    lr=1.0e-4,
    epochs=1000
)

# Log bare messages both to <OUTPUT_PATH>/logging.txt and to the notebook output.
handlers = [
        logging.FileHandler(os.path.join(OUTPUT_PATH, 'logging.txt')),
        logging.StreamHandler()
    ]
logging.basicConfig(level=logging.INFO, format='%(message)s',
                    handlers=handlers)

LDPC_N49_K24.alist


In [5]:
from mamba_ssm import Mamba
from dataset import EbN0_to_std, ECC_Dataset, train, test, sign_to_bin
from torch.utils.data import DataLoader
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
import torch.nn.functional as F
from torch.nn import ModuleList, LayerNorm
import copy

# NOTE(review): hard-coded device name. The visible cells do not read this
# global: train_model computes its own local `device`, and the other cells pass
# the literal "cuda". Confirm whether anything else depends on it.
device = "cuda"

def clones(module, N):
    """Return a ``ModuleList`` holding ``N`` independent deep copies of ``module``.

    The copies share no parameters with each other or with ``module``.
    """
    replicas = ModuleList()
    for _ in range(N):
        replicas.append(copy.deepcopy(module))
    return replicas

class EncoderLayer(torch.nn.Module):
    """One decoder block: a Mamba layer followed by tanh and LayerNorm.

    Args:
        config: Supplies ``d_model`` and ``d_state`` for the Mamba layer.
        length: Sequence length of the input. The LayerNorm is built with
            normalized shape ``(length, d_model)``, so inputs must end in
            exactly those two dimensions.
    """

    def __init__(self, config: Config, length) -> None:
        super().__init__()
        self.mamba = Mamba(
            d_model=config.d_model,
            d_state=config.d_state
        )
        # Normalizes jointly over the trailing (length, d_model) dimensions.
        self.norm = LayerNorm((length, config.d_model))

    def forward(self, x):
        """Apply Mamba, then tanh, then LayerNorm to ``x``."""
        # Call the sub-module itself rather than its .forward so that any
        # registered forward hooks are honoured.
        o = self.mamba(x)
        return self.norm(F.tanh(o))

class ECCM(torch.nn.Module):
    """Decoder built from a stack of ``EncoderLayer`` (Mamba) blocks.

    The input sequence is the concatenation of the magnitude and syndrome
    vectors, of length ``n + syndrom_length``. Each position is scaled into a
    ``d_model``-wide vector by a learned embedding, passed through
    ``config.N_dec`` encoder layers, and projected back to one value per code
    bit (length ``n``).

    Args:
        config: Supplies ``code`` (``n`` and ``pc_matrix``), ``d_model``,
            ``d_state`` (used by the encoder layers) and ``N_dec``.
    """

    def __init__(self, config: Config) -> None:
        super().__init__()
        self.n = config.code.n
        # One syndrome entry per row of the parity-check matrix.
        self.syndrom_length = config.code.pc_matrix.size(0)
        seq_len = self.n + self.syndrom_length

        self.src_embed = torch.nn.Parameter(
            torch.ones((seq_len, config.d_model)))
        self.resize_output_dim = torch.nn.Linear(config.d_model, 1)
        self.resize_output_length = torch.nn.Linear(seq_len, self.n)
        self.norm_output = LayerNorm((self.n,))

        # Xavier-initialise every matrix-shaped parameter created so far:
        # src_embed (overwriting its ones) and the two Linear weights. The
        # Mamba stack is created after this loop, so the encoder layers keep
        # the initialisation given to them by their own constructors.
        for p in self.parameters():
            if p.dim() > 1:
                torch.nn.init.xavier_uniform_(p)

        self.mamba: ModuleList = clones(
            EncoderLayer(config, seq_len), config.N_dec)

    def forward(self, magnitude, syndrome) -> torch.Tensor:
        """Run the decoder on one sample or a batch.

        Args:
            magnitude: Tensor whose last dimension has length ``n``.
            syndrome: Tensor whose last dimension has length
                ``syndrom_length``; leading dimensions must match
                ``magnitude`` so the two can be concatenated.

        Returns:
            Tensor whose last dimension has length ``n``. An unbatched (1-D)
            input yields a leading dimension of size 1, because the embedding
            is broadcast with a leading axis.
        """
        # (..., seq_len) -> (..., seq_len, 1), then scale the per-position
        # embedding rows by the input values.
        emb = torch.cat([magnitude, syndrome], -1).unsqueeze(-1)
        out = self.src_embed.unsqueeze(0) * emb

        for sublayer in self.mamba:
            # Call the module itself (not .forward) so hooks are honoured.
            out = sublayer(out)

        # Shrink the sequence axis seq_len -> n: move it last for the Linear,
        # then move it back.
        out = self.resize_output_length(out.swapaxes(-2, -1))
        # Collapse the feature axis d_model -> 1 and drop it.
        out = self.resize_output_dim(out.swapaxes(-2, -1))
        out = out.squeeze(-1)
        return self.norm_output(F.tanh(out))

    def loss(self, z_pred, z2, y):
        """Return the BCE-with-logits loss and the bit estimate.

        Args:
            z_pred: Model output, treated as logits.
            z2: Tensor whose sign, mapped through ``sign_to_bin``, is the
                binary target.
            y: Tensor whose sign is combined with ``z_pred`` to form the
                bit estimate (presumably the received channel output;
                confirm against the caller).

        Returns:
            ``(loss, x_pred)`` where ``x_pred`` is
            ``sign_to_bin(sign(-z_pred * sign(y)))``.
        """
        loss = F.binary_cross_entropy_with_logits(
            z_pred, sign_to_bin(torch.sign(z2)))
        x_pred = sign_to_bin(torch.sign(-z_pred * torch.sign(y)))
        return loss, x_pred


  from .autonotebook import tqdm as notebook_tqdm


In [4]:


# Build the decoder from the notebook-level `config` and move it to the GPU
# (device name is hard-coded).
model = ECCM(config=config).to("cuda")

def train_model(args: Config, model: torch.nn.Module):
    """Train ``model`` with Adam and cosine-annealed learning rate.

    If ``<args.path>/best_model`` exists, training resumes from those weights;
    otherwise it starts from the model's current weights. After each epoch,
    if the loss returned by ``train`` is the lowest seen during this call, the
    model, optimizer and scheduler state dicts are written to ``args.path`` as
    ``best_model``, ``optimizer_checkpoint`` and ``scheduler_checkpoint``.

    NOTE(review): the best-loss tracker starts at infinity on every call, so
    the first epoch always overwrites an existing ``best_model`` even when it
    is worse than the checkpoint that was just loaded.

    Args:
        args: Run configuration. Fields read here: ``code``, ``path``,
            ``warmup_lr``, ``lr``, ``eta_min``, ``epochs``, ``batch_size``,
            ``test_batch_size``, ``workers``.
        model: Module to train; it is moved to the selected device in place.

    Returns:
        The same ``model`` object after training.
    """
    code = args.code
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Keep the model on the device that is handed to train() below; this is a
    # no-op when the caller already moved it there.
    model.to(device)
    optimizer = Adam(model.parameters(), lr=args.warmup_lr)

    # Resume from the checkpoint under args.path (the function's own
    # configuration, not a notebook global) when one exists.
    checkpoint_path = os.path.join(args.path, 'best_model')
    if os.path.isfile(checkpoint_path):
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # you trust, or pass weights_only=True if the installed torch has it.
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    else:
        logging.info(f"No checkpoint at {checkpoint_path}; training from current weights")
    # optimizer.load_state_dict(torch.load(os.path.join(args.path, 'optimizer_checkpoint')))

    #################################
    EbNo_range_test = range(4, 7)
    EbNo_range_train = range(2, 8)
    std_train = [EbN0_to_std(ii, code.k / code.n) for ii in EbNo_range_train]
    std_test = [EbN0_to_std(ii, code.k / code.n) for ii in EbNo_range_test]
    train_dataloader = DataLoader(ECC_Dataset(code, std_train, len=args.batch_size * 1000, zero_cw=True), batch_size=int(args.batch_size),
                                  shuffle=True, num_workers=args.workers)
    # Only consumed by the (currently disabled) periodic evaluation below.
    test_dataloader_list = [DataLoader(ECC_Dataset(code, [std_test[ii]], len=int(args.test_batch_size), zero_cw=False),
                                       batch_size=int(args.test_batch_size), shuffle=False, num_workers=args.workers) for ii in range(len(std_test))]
    #################################

    best_loss = float('inf')
    # Warm-up phase (disabled):
    # for epoch in range(1,3):
    #     loss, ber, fer = train(model, device, train_dataloader, optimizer,
    #                            epoch, LR=args.warmup_lr, config=args)
    #     if loss < best_loss:
    #         best_loss = loss
    #         torch.save(model.state_dict(), os.path.join(args.path, 'best_model'))

    # Switch from the warm-up learning rate to the main one.
    for param_group in optimizer.param_groups:
        param_group['lr'] = args.lr

    # Anneal over the whole run; tying T_max to args.epochs keeps the schedule
    # correct when the epoch count is changed.
    scheduler = CosineAnnealingLR(optimizer, T_max=args.epochs, eta_min=args.eta_min)
    # scheduler.load_state_dict(torch.load(os.path.join(args.path, 'scheduler_checkpoint')))

    for epoch in range(1, args.epochs + 1):
        loss, ber, fer = train(model, device, train_dataloader, optimizer,
                               epoch, LR=scheduler.get_last_lr()[0], config=args)
        scheduler.step()
        if loss < best_loss:
            best_loss = loss
            torch.save(model.state_dict(), os.path.join(args.path, 'best_model'))
            torch.save(optimizer.state_dict(), os.path.join(args.path, 'optimizer_checkpoint'))
            torch.save(scheduler.state_dict(), os.path.join(args.path, 'scheduler_checkpoint'))

        # if epoch % 200 == 0:
        #     test(model, device, test_dataloader_list, EbNo_range_test)
    return model

# Run training. As the last expression of the cell, the returned model is also
# displayed in the cell output.
train_model(config, model)

  model.load_state_dict(torch.load(os.path.join(config.path, 'best_model')))
Training: 100%|█████████▉| 999/1000 [02:24<00:00,  6.89it/s]Training epoch 1, Batch 1000/1000: LR=1.00e-04, Loss=5.76e-02 BER=2.21e-02 FER=2.89e-01
Training: 100%|██████████| 1000/1000 [02:24<00:00,  6.90it/s]
Epoch 1 Train Time 144.91073369979858s

Training: 100%|█████████▉| 999/1000 [02:23<00:00,  7.30it/s]Training epoch 2, Batch 1000/1000: LR=1.00e-04, Loss=5.68e-02 BER=2.17e-02 FER=2.81e-01
Training: 100%|██████████| 1000/1000 [02:23<00:00,  6.95it/s]
Epoch 2 Train Time 143.88948512077332s

Training: 100%|█████████▉| 999/1000 [02:19<00:00,  7.21it/s]Training epoch 3, Batch 1000/1000: LR=1.00e-04, Loss=5.65e-02 BER=2.17e-02 FER=2.82e-01
Training: 100%|██████████| 1000/1000 [02:19<00:00,  7.18it/s]
Epoch 3 Train Time 139.23549699783325s

Training: 100%|█████████▉| 999/1000 [02:21<00:00,  6.84it/s]Training epoch 4, Batch 1000/1000: LR=1.00e-04, Loss=5.60e-02 BER=2.15e-02 FER=2.79e-01
Training: 100%|██████████

In [6]:

# Rebuild a fresh ECCM, restore the weights saved as <config.path>/best_model,
# then move the model to the GPU. This rebinds the notebook-level `model`.
model = ECCM(config=config)
model.load_state_dict(torch.load(os.path.join(config.path, 'best_model')))
model = model.to("cuda")

  model.load_state_dict(torch.load(os.path.join(config.path, 'best_model')))


In [32]:
from dataset import bin_to_sign

# Inspect the reloaded model on a single sample generated at one Eb/N0 value.
code = config.code
EbNo_range_train = [5]
std_train = [EbN0_to_std(ii, code.k / code.n) for ii in EbNo_range_train]
# Only item 0 of the dataset is used; each item unpacks into six tensors.
m,x,z,y,mag,syn = ECC_Dataset(code, std_train, len=config.batch_size * 1000, zero_cw=False)[0]
z_mul = (y * bin_to_sign(x))
# A single dataset item is 1-D; add a leading axis so z_mul has the same
# 2-D shape as the model output (see the shapes printed below).
if len(z_mul.shape) < 2:
    z_mul = z_mul.unsqueeze(0)
z_pred = model(mag.to('cuda'), syn.to('cuda'))
print(z_pred.shape, z_mul.shape)
# NOTE(review): z_pred is negated before being passed to loss() here; confirm
# this matches the sign convention used during training.
mag, syn, z_pred, z_mul, model.loss(-z_pred, z_mul.to('cuda'), y.to('cuda'))

torch.Size([1, 49]) torch.Size([1, 49])


(tensor([1.1816, 1.4278, 1.2788, 1.9569, 0.7318, 0.8519, 1.9273, 1.4133, 1.3416,
         1.2742, 0.8815, 1.4579, 0.9340, 1.2536, 2.0998, 0.4732, 1.2108, 1.6872,
         0.1434, 1.2015, 0.3701, 0.8556, 0.6629, 0.5573, 0.9793, 1.4336, 2.2401,
         0.4965, 1.2220, 0.4444, 0.1743, 1.7205, 0.9718, 0.6903, 0.5496, 1.5444,
         1.5465, 1.3216, 0.9732, 1.2767, 0.9685, 1.0407, 0.3150, 0.3329, 0.4337,
         1.2422, 0.1428, 1.3265, 1.1648]),
 tensor([ 1.,  1., -1.,  1., -1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,
          1., -1.,  1., -1.,  1.,  1.,  1., -1.,  1.,  1., -1.,  1.,  1.,  1.]),
 tensor([[ 7.4994,  7.1261,  7.3770,  7.3142,  7.3708,  7.4391,  7.2971,  7.3617,
           7.3545,  7.4357,  7.3122,  7.3635,  7.4044,  7.3483,  7.2708,  7.1398,
           7.0714,  7.4451, -5.5284,  7.4245,  7.4163,  7.3740,  7.4801,  6.5448,
           7.4942,  6.8799,  7.3551,  7.4026,  7.2035,  7.3659, -8.4215,  7.3658,
           6.9867,  7.0834,  7.4865,  7.3118,  7.1665,  7.3273, 

In [None]:
# Ideas:
# Bi-directional
# Load and output