<a href="https://colab.research.google.com/github/perfect7613/fnetimplementation/blob/main/Fnet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install torch_geometric
!pip install pywavelets
!pip install torchtext

Collecting torch_geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/63.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m2.6 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.6.1-py3-none-any.whl (1.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m25.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torch_geometric
Successfully installed torch_geometric-2.6.1
Collecting pywavelets
  Downloading pywavelets-1.8.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (9.0 kB)
Downloading pywavelets-1.8.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (4.5 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m4.5/4.5 MB[0m [31m56.7 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pywavelets
Su

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader as GeometricDataLoader
import pywt
import numpy as np
from torch.utils.data import Dataset, DataLoader as TorchDataLoader

In [3]:
class LearnableWaveletLayer(nn.Module):
    def __init__(self, input_dim, num_wavelets, wavelet_name='morl'):
        super().__init__()
        self.input_dim = input_dim  # Now represents character signal dimension (e.g., 1 for scalar values)
        self.num_wavelets = num_wavelets
        self.wavelet_name = wavelet_name
        self.scales = nn.Parameter(torch.rand(num_wavelets) + 1.0)
        self.translations = nn.Parameter(torch.randn(num_wavelets))

    def forward(self, x):
        batch_size, seq_len, _ = x.shape
        wavelet_coeffs = []
        for i in range(self.num_wavelets):
            coeffs_batch = []
            for b in range(batch_size):
                coeffs_embed = []
                for e in range(self.input_dim):
                    data = x[b, :, e].detach().cpu().numpy()
                    scale = torch.abs(self.scales[i]).cpu().item() + 1e-6
                    if scale < 1: scale = 1
                    coeffs, freqs = pywt.cwt(data, scale, self.wavelet_name)
                    coeffs_embed.append(torch.tensor(coeffs, dtype=torch.float32, device=x.device))
                coeffs_batch.append(torch.stack(coeffs_embed))
            wavelet_coeffs.append(torch.stack(coeffs_batch))
        wavelet_coeffs = torch.stack(wavelet_coeffs).permute(1, 2, 0, 3, 4)
        wavelet_coeffs = wavelet_coeffs.reshape(batch_size, seq_len, self.input_dim * self.num_wavelets)
        return wavelet_coeffs

    def inverse_transform(self, coeffs):
        batch_size, seq_len, hidden_dim = coeffs.shape
        reconstructed = torch.zeros(batch_size, seq_len, hidden_dim, device=coeffs.device)
        temp_coeffs = torch.zeros(batch_size, seq_len, hidden_dim * self.num_wavelets, device=coeffs.device)
        temp_coeffs[:, :, :hidden_dim] = coeffs
        temp_coeffs = temp_coeffs.reshape(batch_size, hidden_dim, self.num_wavelets, seq_len).permute(2, 0, 1, 3)
        for i in range(self.num_wavelets):
            current_coeffs = temp_coeffs[i]
            for b in range(batch_size):
                for e in range(hidden_dim):
                    data = current_coeffs[b, e].detach().cpu().numpy()
                    rec_signal = pywt.waverec([data], self.wavelet_name, mode='per')
                    rec_signal = np.array(rec_signal)
                    if len(rec_signal) > seq_len: rec_signal = rec_signal[:seq_len]
                    elif len(rec_signal) < seq_len: rec_signal = np.pad(rec_signal, (0, seq_len - len(rec_signal)))
                    reconstructed[b, :, e] = torch.tensor(rec_signal, dtype=torch.float32, device=coeffs.device)
        return reconstructed

# --- Modified URM without Tokenization ---

class URM(nn.Module):
    def __init__(self, char_set_size, hidden_dim, num_layers, num_wavelets, dropout=0.1):
        super().__init__()
        self.num_wavelets = num_wavelets
        self.hidden_dim = hidden_dim

        # No embedding layer; input_dim = 1 (character signal)
        self.wavelet_layer = LearnableWaveletLayer(input_dim=1, num_wavelets=num_wavelets)

        # Convolutional modulation: input channels = num_wavelets (since input_dim=1)
        self.conv_mod = nn.Conv1d(num_wavelets, hidden_dim, kernel_size=3, padding=1)

        # Graph Convolutional Network
        self.gcn_layers = nn.ModuleList([GCNConv(hidden_dim, hidden_dim) for _ in range(num_layers)])
        self.gcn_dropout = nn.Dropout(dropout)

        # Feedforward Network
        self.ffn1 = nn.Linear(hidden_dim, hidden_dim * 4)
        self.ffn2 = nn.Linear(hidden_dim * 4, hidden_dim)
        self.ffn_dropout = nn.Dropout(dropout)

        # Output layer predicts characters
        self.output_layer = nn.Linear(hidden_dim, char_set_size)

    def forward(self, input_signal, dependency_graphs):
        # input_signal: [batch_size, seq_len, 1] (raw character values)

        # 1. Learnable Wavelet Transform
        wavelet_coeffs = self.wavelet_layer(input_signal)  # [batch_size, seq_len, num_wavelets]

        # 2. Convolutional Modulation
        modulated_coeffs = self.conv_mod(wavelet_coeffs.transpose(1, 2)).transpose(1, 2)
        modulated_coeffs = F.gelu(modulated_coeffs)  # [batch_size, seq_len, hidden_dim]

        # 3. Inverse Wavelet Transform
        reconstructed = self.wavelet_layer.inverse_transform(modulated_coeffs)

        # 4. Prepare for GCN
        graph_data = [Data(x=reconstructed[i], edge_index=adj_matrix.nonzero().t())
                      for i, adj_matrix in enumerate(dependency_graphs)]
        graph_loader = GeometricDataLoader(graph_data, batch_size=len(graph_data))
        batch = next(iter(graph_loader))

        # 5. Graph Convolutional Network
        x = batch.x
        edge_index = batch.edge_index
        for gcn_layer in self.gcn_layers:
            x = gcn_layer(x, edge_index)
            x = F.gelu(x)
            x = self.gcn_dropout(x)
        x = x.reshape(len(graph_data), -1, self.hidden_dim)

        # 6. Feedforward Network
        x = F.gelu(self.ffn1(x))
        x = self.ffn_dropout(x)
        x = self.ffn2(x)

        # 7. Output Layer (predict characters)
        logits = self.output_layer(x)
        return logits

In [4]:
!wget -nc https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

--2025-02-23 07:33:00--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt’


2025-02-23 07:33:00 (27.4 MB/s) - ‘input.txt’ saved [1115394/1115394]



In [12]:
with open('input.txt', 'r') as f:
    text = f.read()

# Create character set
chars = sorted(list(set(text)))
char_to_idx = {ch: i for i, ch in enumerate(chars)}
idx_to_char = {i: ch for i, ch in enumerate(chars)}
char_set_size = len(chars)

# Hyperparameters
hidden_dim = 64
num_layers = 3
num_wavelets = 4
batch_size = 4
seq_len = 10
num_epochs = 600
learning_rate = 0.001

In [13]:
def text_to_signal(text, seq_len, char_to_idx):
    signal = [char_to_idx[ch] for ch in text]
    sequences = []
    labels = []
    for i in range(0, len(signal) - seq_len, seq_len):
        seq = signal[i:i + seq_len]
        label = signal[i + 1:i + seq_len + 1]  # Next-character prediction
        sequences.append(seq)
        labels.append(label)
    return sequences, labels

sequences, labels = text_to_signal(text, seq_len, char_to_idx)
sequences = torch.tensor(sequences[:batch_size * 10], dtype=torch.float32).unsqueeze(-1)  # [N, seq_len, 1]
labels = torch.tensor(labels[:batch_size * 10], dtype=torch.long)

# Generate simple dependency graphs (linear chain)
def get_dependency_graph(batch_size, seq_len):
    dependency_graphs = []
    for _ in range(batch_size):
        adj_matrix = torch.zeros((seq_len, seq_len))
        for i in range(seq_len - 1):
            adj_matrix[i, i + 1] = 1
            adj_matrix[i + 1, i] = 1
        dependency_graphs.append(adj_matrix)
    return dependency_graphs

dependency_graphs_list = get_dependency_graph(len(sequences), seq_len)

In [14]:
class TextDataset(Dataset):
    def __init__(self, input_signal, dependency_graphs, labels):
        self.input_signal = input_signal
        self.dependency_graphs = dependency_graphs
        self.labels = labels

    def __len__(self):
        return len(self.input_signal)

    def __getitem__(self, idx):
        return {
            'input_signal': self.input_signal[idx],
            'dependency_graph': self.dependency_graphs[idx],
            'label': self.labels[idx]
        }

dataset = TextDataset(sequences, dependency_graphs_list, labels)
dataloader = TorchDataLoader(dataset, batch_size=batch_size, shuffle=True)

# --- Model, Loss, and Optimizer ---

model = URM(char_set_size, hidden_dim, num_layers, num_wavelets)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# --- Training Loop ---

for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    for batch in dataloader:
        input_signal_batch = batch['input_signal']
        dependency_graphs_batch = batch['dependency_graph']
        labels_batch = batch['label']
        logits = model(input_signal_batch, dependency_graphs_batch)
        loss = criterion(logits.view(-1, char_set_size), labels_batch.view(-1))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    avg_loss = total_loss / len(dataloader)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}")

print("Training Complete!")

Epoch 1/600, Loss: 4.0000
Epoch 2/600, Loss: 3.5341
Epoch 3/600, Loss: 3.3350
Epoch 4/600, Loss: 3.2742
Epoch 5/600, Loss: 3.2174
Epoch 6/600, Loss: 3.1910
Epoch 7/600, Loss: 3.1621
Epoch 8/600, Loss: 3.1312
Epoch 9/600, Loss: 3.1063
Epoch 10/600, Loss: 3.0826
Epoch 11/600, Loss: 3.0507
Epoch 12/600, Loss: 3.0138
Epoch 13/600, Loss: 2.9796
Epoch 14/600, Loss: 2.9516
Epoch 15/600, Loss: 2.9508
Epoch 16/600, Loss: 2.9288
Epoch 17/600, Loss: 2.8720
Epoch 18/600, Loss: 2.8248
Epoch 19/600, Loss: 2.8017
Epoch 20/600, Loss: 2.7901
Epoch 21/600, Loss: 2.7767
Epoch 22/600, Loss: 2.7615
Epoch 23/600, Loss: 2.7226
Epoch 24/600, Loss: 2.7257
Epoch 25/600, Loss: 2.6756
Epoch 26/600, Loss: 2.6058
Epoch 27/600, Loss: 2.5834
Epoch 28/600, Loss: 2.5717
Epoch 29/600, Loss: 2.5462
Epoch 30/600, Loss: 2.5245
Epoch 31/600, Loss: 2.4781
Epoch 32/600, Loss: 2.4439
Epoch 33/600, Loss: 2.4781
Epoch 34/600, Loss: 2.4350
Epoch 35/600, Loss: 2.3962
Epoch 36/600, Loss: 2.3750
Epoch 37/600, Loss: 2.3329
Epoch 38/6

In [19]:
model.eval()
with torch.no_grad():
    sample_input = sequences[:1]  # [1, seq_len, 1]
    sample_graph = dependency_graphs_list[:1]
    logits = model(sample_input, sample_graph)
    predictions = torch.argmax(logits, dim=-1)
    predicted_chars = [idx_to_char[idx.item()] for idx in predictions[0]]
    input_chars = [idx_to_char[int(idx.item())] for idx in sample_input[0, :, 0]]
    print("\nSample Input Characters:", ''.join(input_chars))
    print("Predicted Next Characters:", ''.join(predicted_chars))


Sample Input Characters: First Citi
Predicted Next Characters: irst Citiz
