In [None]:
# ============================================
# 🚀 AILO COLAB MINER v2.1 - FULL MODEL
# ============================================
# 1. Settings (⚙️) → Accelerator → GPU T4 x2 or P100
# 2. Settings ‚Üí Internet ‚Üí ON
# 3. Enter wallet below, then Run (Shift+Enter)
# ============================================

# Wallet address sent with every server request (training data, heartbeat,
# gradient submission). The startup check below aborts when it is shorter
# than 40 characters, so it must be filled in before running.
WALLET = ""  # 👈 YOUR 0x... WALLET HERE

# ============================================
import subprocess, sys
try: import aiohttp
except: subprocess.run([sys.executable, '-m', 'pip', 'install', 'aiohttp', '-q'])

import torch, torch.nn as nn, numpy as np, requests, aiohttp, asyncio, time, base64, gzip, gc, json

# Server endpoints and the client version string reported on registration.
SERVER = "https://ailo.site"
API = "https://ailo.site/api"
VER = "2.1.0-colab-full"

# Gradients are only submitted while the trainer's best loss is below this.
MAX_LOSS = 3.5

# Preflight checks: stop the cell early when there is no GPU or no wallet.
if not torch.cuda.is_available():
    print("‚ùå GPU not enabled! Settings ‚Üí Accelerator ‚Üí GPU")
    raise SystemExit
if len(WALLET) < 40:
    print("‚ùå Enter wallet! Get one: https://ailo.site/wallet.html")
    raise SystemExit

# Device facts reused later in heartbeat and submission payloads.
gpu_name = torch.cuda.get_device_name(0)
vram = torch.cuda.get_device_properties(0).total_memory / 1e9  # bytes -> GB
print(f"‚úÖ GPU: {gpu_name} | VRAM: {vram:.1f}GB")
print(f"‚úÖ Wallet: {WALLET[:12]}...")

# Model architecture matching init_full_model.py exactly
VOCAB_SIZE = 50257
D_MODEL = 1600
N_HEADS = 25
N_LAYERS = 24
D_FF = 6400
MAX_SEQ = 512

class AILO1B(nn.Module):
    """Token embedding + learned positions -> Transformer encoder stack -> vocab logits.

    Sizes come from the module-level constants (VOCAB_SIZE, D_MODEL, N_HEADS,
    N_LAYERS, D_FF, MAX_SEQ).

    NOTE(review): the encoder is called without any attention mask, so every
    position can attend to later positions — confirm this is intended for the
    shifted next-token targets the trainer builds.
    """

    def __init__(self):
        super().__init__()
        # Attribute names and creation order are deliberately unchanged: the
        # names are the state_dict keys and the order drives random init.
        self.emb = nn.Embedding(VOCAB_SIZE, D_MODEL)
        self.pos = nn.Parameter(torch.zeros(1, MAX_SEQ, D_MODEL))
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=D_MODEL,
            nhead=N_HEADS,
            dim_feedforward=D_FF,
            dropout=0.1,
            batch_first=True,
        )
        self.tf = nn.TransformerEncoder(encoder_layer, num_layers=N_LAYERS)
        self.out = nn.Linear(D_MODEL, VOCAB_SIZE)

    def forward(self, x):
        """Return logits for token ids ``x`` (batch-first; dim 1 is the sequence)."""
        seq_len = x.size(1)
        # Embeddings are scaled by a fixed factor of 40 before positions are added.
        scaled_tokens = self.emb(x) * 40
        hidden = scaled_tokens + self.pos[:, :seq_len]
        encoded = self.tf(hidden)
        return self.out(encoded)

class FullTrainer:
    """Local trainer for AILO1B plus its checkpoint / gradient exchange.

    Usage, as driven by ``run()``: ``init()`` builds the model, optimizer and
    loss; ``sync_full()`` / ``sync_partial()`` pull server weights;
    ``batch()`` fetches one training sample; ``step()`` runs one micro-batch
    with gradient accumulation; ``grads()`` encodes the most recent
    output-layer gradient for upload.
    """

    ACCUM_STEPS = 16      # micro-batches accumulated per optimizer step
    SEQ_LEN = 32          # tokens per training sample
    CLIP_NORM = 0.5       # max total gradient norm before the optimizer step
    GRAD_STRIDE = 10      # keep every 10th gradient value when encoding
    FALLBACK_LOSS = 3.5   # target loss used when the server reports no number

    def __init__(self):
        self.dev = torch.device('cuda')
        self.m = None             # model, created by init()
        self.opt = None           # optimizer, created by init()
        self.crit = None          # loss function, created by init()
        self.steps = 0            # micro-batches that completed backward()
        self.acc = 0              # micro-batches since the last optimizer step
        self.best_loss = float('inf')
        self.saved_grads = None   # CPU copy of the last finite out.weight grad
        self.synced = False       # True once server weights were loaded

    @staticmethod
    def _target_loss(server_loss):
        """Return ``server_loss`` as a float, or FALLBACK_LOSS if it is not numeric."""
        if isinstance(server_loss, (int, float)):
            return float(server_loss)
        return FullTrainer.FALLBACK_LOSS

    def init(self):
        """Create the model on the GPU together with the SGD optimizer and loss."""
        print("üß† Loading AILO-1B (899M params)...")
        gc.collect(); torch.cuda.empty_cache()
        self.m = AILO1B().to(self.dev)  # Start in FP32 for stability
        params = sum(p.numel() for p in self.m.parameters())
        print(f"   ‚úÖ Ready! {params:,} params | GPU: {torch.cuda.memory_allocated()/1e9:.1f}GB")
        # Use SGD with lower LR for stability
        self.opt = torch.optim.SGD(self.m.parameters(), lr=5e-5, momentum=0.9)
        self.crit = nn.CrossEntropyLoss()

    def sync_full(self) -> bool:
        """Download the full FP16 checkpoint and load it into the model.

        The layer table from ``/checkpoint/full/info`` gives each tensor's
        name, element offset, element count and shape inside the flat
        buffer. Layers that are truncated or contain NaN/Inf are skipped.
        Any failure falls back to ``sync_partial()``.

        Returns:
            True if weights were loaded (by this method or the fallback).
        """
        try:
            print("üîÑ Syncing full model (1.8GB)...")

            # Get metadata
            r = requests.get(f"{SERVER}/api/model/weights", timeout=30)
            if r.status_code != 200:
                return self.sync_partial()
            server_loss = r.json().get('bestCheckpointLoss', 999)

            # Get layer info
            info_r = requests.get(f"{SERVER}/api/model/checkpoint/full/info", timeout=10)
            if info_r.status_code != 200:
                print("   ‚ö†Ô∏è No full checkpoint info, using partial")
                return self.sync_partial()

            layers = info_r.json().get('layers', [])
            if not layers:
                return self.sync_partial()

            print(f"   üì• Downloading full checkpoint (loss: {server_loss})...")
            # The whole body is needed at once, so no streaming: this also
            # guarantees the connection is released on the non-200 path.
            cr = requests.get(f"{SERVER}/api/model/checkpoint/full", timeout=300)
            if cr.status_code != 200:
                return self.sync_partial()

            weights = np.frombuffer(cr.content, dtype=np.float16)
            print(f"   üì¶ Downloaded {len(weights):,} params")

            # Build state dict from checkpoint
            state_dict = {}
            for layer in layers:
                name = layer['name']
                offset = layer['offset']
                size = layer['size']
                shape = tuple(layer['shape'])

                w = weights[offset:offset + size].astype(np.float32)
                if len(w) != size:
                    print(f"   ‚ö†Ô∏è Size mismatch for {name}")
                    continue

                # Check for NaN/Inf
                if not np.isfinite(w).all():
                    print(f"   ‚ö†Ô∏è Bad values in {name}, skipping")
                    continue

                state_dict[name] = torch.from_numpy(w).reshape(shape)

            print(f"   üì¶ Built state_dict with {len(state_dict)} layers")

            # Load using PyTorch's load_state_dict
            result = self.m.load_state_dict(state_dict, strict=False)
            print(f"   Loaded: {len(result.unexpected_keys)} unexpected, {len(result.missing_keys)} missing")

            # Only count tensors that actually landed in a model parameter;
            # a checkpoint whose names all miss must not mark us as synced.
            matched = len(state_dict) - len(result.unexpected_keys)
            if matched > 0:
                self.synced = True
                self.best_loss = self._target_loss(server_loss)
                print(f"   ‚úÖ Full model loaded! Target loss: {self.best_loss:.4f}")
                return True

            return self.sync_partial()

        except Exception as e:
            print(f"   ‚ö†Ô∏è Full sync failed: {e}")
            return self.sync_partial()

    def sync_partial(self) -> bool:
        """Fallback: load only ``out.weight`` from the partial checkpoint.

        Returns:
            True if the output-layer weights were replaced, False otherwise
            (HTTP error, short buffer, non-finite values, or any exception).
        """
        try:
            print("   üì• Falling back to partial checkpoint...")
            r = requests.get(f"{SERVER}/api/model/weights", timeout=30)
            if r.status_code != 200:
                return False
            server_loss = r.json().get('bestCheckpointLoss', 999)

            cr = requests.get(f"{SERVER}/api/model/checkpoint", timeout=60)
            if cr.status_code != 200:
                return False

            target = self.m.out.weight
            exp = target.numel()
            raw = np.frombuffer(cr.content, dtype=np.float16)
            if len(raw) < exp:
                return False

            # Slice before converting so only the needed values become FP32.
            w = raw[:exp].astype(np.float32)
            # Same NaN/Inf guard as the full sync applies per layer.
            if not np.isfinite(w).all():
                print("   ‚ö†Ô∏è Bad values in out.weight, skipping")
                return False

            with torch.no_grad():
                target.copy_(torch.from_numpy(w).reshape_as(target))
            self.synced = True
            self.best_loss = self._target_loss(server_loss)
            print(f"   ‚úÖ Partial model loaded ({exp:,} weights)")
            return True
        except Exception as e:
            print(f"   ‚ö†Ô∏è Partial sync failed: {e}")
            return False

    def batch(self, w):
        """Fetch one article for wallet ``w`` and turn it into an (input, target) pair.

        Tokenization is per character: each character's code point modulo
        VOCAB_SIZE. The target is the input shifted by one position.

        Returns:
            Two ``(1, SEQ_LEN)`` long tensors, or ``(None, None)`` when the
            request fails or the article is shorter than SEQ_LEN + 1.
        """
        needed = self.SEQ_LEN + 1
        try:
            r = requests.get(f"{API}/cuda/training-data", params={'batchSize': 1, 'wallet': w}, timeout=10)
            if r.ok:
                articles = r.json().get('articles') or ['']
                t = articles[0]
                if isinstance(t, str) and len(t) >= needed:
                    tk = [ord(c) % VOCAB_SIZE for c in t[:needed]]
                    return (torch.tensor([tk[:-1]], dtype=torch.long),
                            torch.tensor([tk[1:]], dtype=torch.long))
        except Exception:
            # Best effort: the caller retries. Not a bare ``except`` so that
            # KeyboardInterrupt still reaches run() and stops the miner.
            pass
        return None, None

    def step(self, x, y):
        """Run one micro-batch; apply the optimizer every ACCUM_STEPS calls.

        Returns:
            The unscaled loss as a float, or None when ``x`` is None or the
            loss is NaN/Inf (in which case nothing is back-propagated).
        """
        if x is None:
            return None
        self.m.train()
        x, y = x.to(self.dev), y.to(self.dev)

        # No autocast - keep FP32 for stability
        out = self.m(x)
        # Flatten to (tokens, classes); the class count is read from the logits.
        loss = self.crit(out.view(-1, out.size(-1)), y.view(-1)) / self.ACCUM_STEPS

        # Check for NaN before backprop
        if not torch.isfinite(loss):
            return None

        loss.backward()
        self.acc += 1

        if self.acc >= self.ACCUM_STEPS:
            total_norm = nn.utils.clip_grad_norm_(self.m.parameters(), self.CLIP_NORM)
            # A NaN/Inf norm turns every clipped gradient into NaN; stepping
            # would poison all weights, so drop this accumulation window.
            if bool(torch.isfinite(torch.as_tensor(total_norm))):
                g = self.m.out.weight.grad
                if g is not None and not torch.isnan(g).any():
                    self.saved_grads = g.clone().cpu()
                self.opt.step()
            self.opt.zero_grad(set_to_none=True)
            self.acc = 0

        self.steps += 1
        l = loss.item() * self.ACCUM_STEPS
        if not np.isnan(l) and l < self.best_loss:
            self.best_loss = l
        if self.steps % 50 == 0:
            gc.collect(); torch.cuda.empty_cache()
        return l

    def grads(self):
        """Encode the saved output-layer gradient for upload.

        Every GRAD_STRIDE-th value is kept, cast to FP16, gzip-compressed
        (level 4) and base64-encoded.

        Returns:
            The base64 string, or None if no gradient has been saved yet.
        """
        g = self.saved_grads
        if g is None:
            return None
        gc.collect()
        flat = g.float().flatten()
        sample = flat[::self.GRAD_STRIDE]
        comp = gzip.compress(sample.half().numpy().tobytes(), 4)
        return base64.b64encode(comp).decode()

async def hb(w, tps):
    """Send a best-effort heartbeat / registration for wallet ``w``.

    Posts the client version and device info (GPU name, VRAM, and ``tps``
    as the reported hashrate) to ``{API}/cuda/register``. Failures are
    ignored so a flaky server never stops training.
    """
    payload = {
        'wallet': w,
        'clientVersion': VER,
        'deviceInfo': {'gpu_name': f'{gpu_name}-Colab', 'vram_gb': vram, 'hashrate': tps},
    }
    # The training loop awaits this inline, so bound the wait instead of
    # relying on aiohttp's much longer default total timeout.
    timeout = aiohttp.ClientTimeout(total=10)
    try:
        async with aiohttp.ClientSession(timeout=timeout) as s:
            # Context manager releases the response/connection promptly.
            async with s.post(f"{API}/cuda/register", json=payload):
                pass
    except Exception:
        # Not a bare ``except``: cancellation and KeyboardInterrupt must
        # still propagate to the caller.
        pass

async def submit(w, g, s, l):
    """Upload encoded gradients and return the reward granted by the server.

    Args:
        w: wallet address.
        g: base64 gradient payload (as produced by ``FullTrainer.grads``).
        s: step counter, sent as ``epoch``.
        l: loss value reported with the submission.

    Returns:
        The ``reward`` field of a 200 response as a float; 0 on a non-200
        status, a missing / null / non-numeric reward, or any request error.
    """
    payload = {'wallet': w, 'gradients': g, 'epoch': s, 'loss': l, 'gpu': f'{gpu_name}-Colab'}
    try:
        async with aiohttp.ClientSession() as ss:
            async with ss.post(f"{API}/cuda/submit", json=payload) as r:
                if r.status == 200:
                    reward = (await r.json()).get('reward', 0)
                    # The caller adds and formats this as a number, so
                    # normalise null / string rewards here.
                    return float(reward or 0)
                print(f"   ‚ö†Ô∏è Submit failed: HTTP {r.status}")
    except Exception as e:
        # Not a bare ``except``: cancellation and KeyboardInterrupt must
        # still propagate. Submissions are rare, so report the reason.
        print(f"   ‚ö†Ô∏è Submit failed: {e}")
    return 0

async def run():
    """Main mining loop: train continuously, heartbeat, and submit gradients.

    Registers the wallet, builds and syncs the trainer, then loops until
    interrupted: fetch a sample, run a training step, send a heartbeat at
    most every 10 s, log every 25 steps, and every 300 s submit gradients
    (when eligible) and re-sync the output layer from the server.
    """
    banner = "=" * 50
    print(banner)
    print(f"  üöÄ AILO COLAB FULL MODEL v{VER}")
    print(f"  üì¶ Training ALL 899M parameters!")
    print(banner)

    await hb(WALLET, 0)
    trainer = FullTrainer()
    trainer.init()

    synced_ok = trainer.sync_full()
    if not synced_ok:
        print("\n‚ö†Ô∏è Could not sync with global model")

    print("\n‚õèÔ∏è MINING STARTED!")
    print(f"üìä Dashboard: https://ailo.site/dashboard.html?wallet={WALLET[:12]}\n")

    total_reward = 0.0
    last_submit = time.time()
    last_heartbeat = time.time()
    tps = 0.0
    bad_steps = 0  # consecutive steps that returned no loss

    try:
        while True:
            started = time.time()
            x, y = trainer.batch(WALLET)
            if x is None:
                # No data available: back off briefly and try again.
                await asyncio.sleep(0.5)
                continue

            loss = trainer.step(x, y)
            if loss is None:
                bad_steps += 1
                if bad_steps > 10:
                    print("‚ö†Ô∏è Too many NaN, re-syncing...")
                    trainer.sync_partial()  # Fallback to partial
                    bad_steps = 0
                continue
            bad_steps = 0

            # 32 tokens per sample; the small epsilon avoids division by zero.
            elapsed = time.time() - started
            tps = 32 / (elapsed + 0.001)

            if time.time() - last_heartbeat >= 10:
                await hb(WALLET, tps)
                last_heartbeat = time.time()

            if trainer.steps % 25 == 0:
                print(f"Step {trainer.steps} | Loss: {loss:.4f} | Best: {trainer.best_loss:.4f} | {tps:.0f} tok/s | üí∞ {rew:.4f}".replace("{rew:.4f}", "") if False else f"Step {trainer.steps} | Loss: {loss:.4f} | Best: {trainer.best_loss:.4f} | {tps:.0f} tok/s | üí∞ {total_reward:.4f}")

            if time.time() - last_submit >= 300:
                gc.collect()
                torch.cuda.empty_cache()
                encoded = trainer.grads()
                eligible = (
                    encoded
                    and trainer.synced
                    and trainer.best_loss < MAX_LOSS
                    and not np.isnan(trainer.best_loss)
                )
                if eligible:
                    print(f"\nüì§ Submitting gradients (loss: {trainer.best_loss:.4f})...")
                    reward = await submit(WALLET, encoded, trainer.steps, trainer.best_loss)
                    total_reward += reward
                    print(f"üí∞ +{reward:.4f} (Total: {total_reward:.4f})")
                trainer.sync_partial()  # Re-sync with latest
                last_submit = time.time()
                print()
    except KeyboardInterrupt:
        print(f"\n‚èπÔ∏è Stopped. Total: {total_reward:.4f} ALC")

# Start the miner. Top-level `await` is only valid where the host already
# runs an event loop (e.g. a notebook cell) — presumably this file is a
# notebook export; TODO confirm before running it as a plain script.
await run()