## Intro

This README has a subsection for each example `*.py` file.

Please follow the instructions in [README.md](../README.md) to install torch_xla2,
then install the requirements for all of the examples with

```bash
pip install -r requirements.txt
```

## basic_training.py

This file was constructed by first copying and pasting code fragments from this
PyTorch training tutorial:
https://pytorch.org/tutorials/beginner/introyt/trainingyt.html

Then we added a few lines of code to move each `torch.Tensor` onto an XLA device.

Example:

```python
state_dict = pytree.tree_map_only(torch.Tensor,
    torch_xla2.tensor.move_to_device, state_dict)
```

This fragment moves the `state_dict` to XLA devices; the `state_dict` is then
passed back to the model via `load_state_dict`.
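
Below is a minimal sketch of that full round trip with the imports it needs.
It is an illustration rather than the exact code in the file: in particular,
`assign=True` (a standard `torch.nn.Module.load_state_dict` option that keeps
the given tensors instead of copying their values) is our guess at how the
moved tensors are retained.

```python
import torch
import torch_xla2.tensor
from torch.utils import _pytree as pytree

# Assumes `model` is an ordinary torch.nn.Module built on CPU.
state_dict = model.state_dict()

# Replace every torch.Tensor leaf of the state dict with an XLA tensor.
state_dict = pytree.tree_map_only(
    torch.Tensor, torch_xla2.tensor.move_to_device, state_dict)

# assign=True (our assumption) keeps the moved tensors as the parameters
# instead of copying their values back into the original CPU tensors.
model.load_state_dict(state_dict, assign=True)
```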

Then you can train the model. This shows the minimum needed to train a model on
XLA devices. Performance is not as good as it could be because we didn't use
`jax.jit`; this is intentional, as the example is meant to showcase the minimum
code change.
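
For illustration, here is a sketch of the per-batch step under that setup,
assuming `model`, `optimizer`, `loss_fn`, and `training_loader` are the ones
from the PyTorch tutorial; the only change is moving each batch onto the XLA
device before the forward pass.

```python
import torch_xla2.tensor

for inputs, labels in training_loader:
    # The only XLA-specific lines: move the batch onto the XLA device.
    inputs = torch_xla2.tensor.move_to_device(inputs)
    labels = torch_xla2.tensor.move_to_device(labels)

    # The rest is the unmodified tutorial training step.
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = loss_fn(outputs, labels)
    loss.backward()
    optimizer.step()
```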

Example run:
```bash
(xla2) hanq-macbookpro:examples hanq$ python basic_training.py
Training set has 60000 instances
Validation set has 10000 instances
Bag Dress Sneaker T-shirt/top
tensor([[0.8820, 0.3807, 0.3010, 0.9266, 0.7253, 0.9265, 0.0688, 0.4567, 0.7035,
         0.2279],
        [0.3253, 0.1558, 0.1274, 0.2776, 0.2590, 0.4169, 0.1881, 0.7423, 0.4561,
         0.5985],
        [0.5067, 0.4514, 0.9758, 0.6088, 0.7438, 0.6811, 0.9609, 0.3572, 0.4504,
         0.8738],
        [0.1850, 0.1217, 0.8551, 0.2120, 0.9902, 0.7623, 0.1658, 0.6980, 0.3086,
         0.5709]])
tensor([1, 5, 3, 7])
Total loss for this batch: 2.325265645980835
EPOCH 1:
  batch 1000 loss: 1.041275198560208
  batch 2000 loss: 0.6450189483696595
  batch 3000 loss: 0.5793989677671343
  batch 4000 loss: 0.5170258888280951
  batch 5000 loss: 0.4920090722264722
  batch 6000 loss: 0.48910293977567926
  batch 7000 loss: 0.48058812761632724
  batch 8000 loss: 0.47159107415075413
  batch 9000 loss: 0.4712311488997657
  batch 10000 loss: 0.4675815168160479
  batch 11000 loss: 0.43210567891132085
  batch 12000 loss: 0.445208148030797
  batch 13000 loss: 0.4119230824254337
  batch 14000 loss: 0.4190662656680215
  batch 15000 loss: 0.4094535468676477
LOSS train 0.4094535468676477 valid XLA
```

## basic_training_jax.py

This file was constructed by first copying and pasting code fragments from this
PyTorch training tutorial:
https://pytorch.org/tutorials/beginner/introyt/trainingyt.html

Then we replaced the torch optimizer with an `optax` optimizer, and used
`jax.grad` to compute gradients instead of `torch.Tensor.backward()`.

Then you can train the model using the JAX ecosystem's training loop. This is
meant to showcase how easy it is to integrate with JAX.
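
Here is a minimal sketch of what such a step can look like, assuming
`jax_params` is a pytree of `jax.Array` weights and `loss_fn(params, inputs,
labels)` returns a scalar loss; the names are illustrative, not necessarily
the ones used in `basic_training_jax.py`.

```python
import jax
import optax

# An optax optimizer takes the role of the torch optimizer.
optimizer = optax.sgd(learning_rate=1e-3)
opt_state = optimizer.init(jax_params)

def train_step(params, opt_state, inputs, labels):
    # jax.grad replaces torch.Tensor.backward(): it returns the gradient
    # of the scalar loss with respect to `params`.
    grads = jax.grad(loss_fn)(params, inputs, labels)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state
```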

Example run:
```bash
(xla2) hanq-macbookpro:examples hanq$ python basic_training_jax.py
Training set has 60000 instances
Validation set has 10000 instances
Pullover Ankle Boot Pullover Ankle Boot
tensor([[0.5279, 0.8340, 0.3131, 0.8608, 0.3668, 0.6192, 0.7453, 0.3261, 0.8872,
         0.1854],
        [0.7414, 0.8309, 0.8127, 0.8866, 0.2475, 0.2664, 0.0327, 0.6918, 0.6010,
         0.2766],
        [0.3304, 0.9135, 0.2762, 0.6737, 0.0480, 0.6150, 0.5610, 0.5804, 0.9607,
         0.6450],
        [0.9464, 0.9439, 0.3122, 0.1814, 0.1194, 0.5012, 0.2058, 0.1170, 0.7377,
         0.7453]])
tensor([1, 5, 3, 7])
Total loss for this batch: 2.4054245948791504
EPOCH 1:
  batch 1000 loss: 1.0705260595591972
  batch 2000 loss: 1.0997755021179327
  batch 3000 loss: 1.0186579653513108
  batch 4000 loss: 0.9090727646966116
  batch 5000 loss: 0.8309370622411024
  batch 6000 loss: 0.8702225417760783
  batch 7000 loss: 0.8750176187023462
  batch 8000 loss: 0.9652624803795453
  batch 9000 loss: 0.8688667197711766
  batch 10000 loss: 0.8021814124770199
  batch 11000 loss: 0.8000540231048071
  batch 12000 loss: 0.9150884484921057
  batch 13000 loss: 0.819690621060171
  batch 14000 loss: 0.8569030471532278
  batch 15000 loss: 0.8740896808278603
LOSS train 0.8740896808278603 valid 2.3132264614105225
```